2017-07-17 20 views
4

LSTMの保存と復元に役立つことができれば、本当に感謝しています。LSTMレイヤを復元する方法

# LSTM cell 
cell = tf.contrib.rnn.LSTMCell(n_hidden) 
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32) 

outputs = tf.transpose(output, [1, 0, 2]) 
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1) 

# Saver function 
saver = tf.train.Saver() 
saver.save(sess, 'test-model') 

セーバーがモデルを保存し、私はLSTMの重みとバイアスを保存し、復元することができます -

私はこのLSTM層を持っています。しかし、私はこのLSTM層を復元し、それに新しい入力を与える必要があります。

は、モデル全体を復元するには、私がやっている:

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('test-model.meta') 
    saver.restore(sess, tf.train.latest_checkpoint('./')) 
  1. それは私が事前に訓練を受けた重みとバイアスでLSTM細胞を初期化することは可能ですか?

  2. このLSTM層を復元するにはどうすればよいですか?

ありがとうございました!

答えて

1

すでにモデルがロードされており、モデルの重みもロードされています。グラフからテンソルを取得して推論に使用するには、get_tensor_by_nameを使用するだけです。

例:上記の例word_vecoutputsにおいて

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('test-model.meta') 
    saver.restore(sess, tf.train.latest_checkpoint('./')) 

    # Get the tensors by their variable name 
    word_vec = = detection_graph.get_tensor_by_name('word_vec:0') 
    output_tensor = detection_graph.get_tensor_by_name('outputs:0') 

    sess.run(output_tensor, feed_dict={word_vec: ...}) 

は、グラフの作成時テンソルに割り当てられた名前です。名前で名前を呼び出すことができるように、名前を割り当ててください。

+0

私の質問にお答えいただきありがとうございます!ほんとうにありがとう。だから私はLSTMを通して 'word_vec'を再び渡す必要はありませんか?これはどのように正確に機能しますか? – AnnaR

+0

これは単なる例です。グラフで定義した入力をfeed_dictに渡す必要があります。 –

+0

ありがとうございます! – AnnaR

関連する問題