2017-07-17 19 views
0

テキストを生成するLSTMに取り組んでいます。以前訓練されたモデルを再利用する際に問題があります。私はリソースとしてtensorflow websiteを使用している間私のコードを分解しました。テンソルフローモデル(LSTM)の保存と復元の問題

ここで私は私のすべての変数を確立します。

ここ
graph = tf.Graph() 

with graph.as_default(): 
    global_step = tf.Variable(0) 

    data = tf.placeholder(tf.float32, [batch_size, len_section, char_size]) 
    labels = tf.placeholder(tf.float32, [batch_size, char_size]) 

    ..... 

    #Reset at the beginning of each test 
    reset_test_state = tf.group(test_output.assign(tf.zeros([1, hidden_nodes])), 
           test_state.assign(tf.zeros([1, hidden_nodes]))) 

    #LSTM 
    test_output, test_state = lstm(test_data, test_output, test_state) 
    test_prediction = tf.nn.softmax(tf.matmul(test_output, w) + b) 

    saver = tf.train.Saver() 

、私は私のモデルを訓練し、30回の反復

with tf.Session(graph = graph) as sess: 
    tf.global_variables_initializer().run() 
    offset = 0 

    for step in range(10000): 

     offset = offset % len(X) 

     if offset <= (len(X) - batch_size): 

      batch_data = X[offset: offset + batch_size] 
      batch_labels = y[offset:offset+batch_size] 
      offset += batch_size 

     else: 
      to_add = batch_size - (len(X) - offset) 
      batch_data = np.concatenate((X[offset: len(X)], X[0: to_add])) 
      batch_labels = np.concatenate((y[offset: len(X)], y[0: to_add])) 
      offset = to_add 

     _, training_loss = sess.run([optimizer, loss], feed_dict = {data : batch_data, labels : batch_labels}) 

     if step % 10 == 0: 
      print('training loss at step %d: %.2f (%s)' % (step, training_loss, datetime.datetime.now())) 

     if step % save_every == 0: 
      saver.save(sess, checkpoint_directory + '/model.ckpt', global_step=step) 

     if step == 30: 
      break 

でチェックポイントを保存しています、私はそのディレクトリを見て、次のファイルがなかったです作成:

enter image description here

彼女E私はおそらく私の訓練を受けたモデルを復元し、それをテストしてい:

with tf.Session(graph=graph) as sess: 
    #standard init step 
    offset = 0 
    saver = tf.train.Saver() 
    saver.restore(sess, "/ckpt/model-150.meta") 
    tf.global_variables_initializer().run() 

    test_start = "I plan to make this world a better place " 
    test_generated = test_start 

.... 

これを行った後、私は次のエラーを取得する:

DataLossError (see above for traceback): Unable to open table file /ckpt/model.ckpt-30.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator? 

私は私が間違ってやっているかなりよく分かりません。チュートリアルはかなり簡単なようですが、私は明らかに何かが欠けています。どんなフィードバックでも大歓迎です。

答えて

1

最初に、チェックポイントから復元した後にすべての変数を初期化すると、訓練された値の代わりにランダムな初期値が得られることに注意してください。

第2に、これを自分で実装する代わりにtf.estimator.Estimatorを使用すると、保存/復元するのがはるかに簡単です。

第3に、model-150.metaをどのように渡して復元するのかわかりませんが、model-30.metaについてのエラーが表示されます。私は信じていますが、あなたはmodel-30(.meta接尾辞なし)だけを渡すべきです。

関連する問題