2017-08-20 13 views
1

更新:tensorflow-cpuを使用すると、以下のコードが正しく機能することがわかりました。この問題は、tensorflow-gpuを使用するときだけ持続します。どうすればそれを動作させることができますか?TensorflowのモデルをGPUで保存していませんか?

私のコードで問題を見つけることができません - 変数を保存しようとしていますが、再読み込みしようとしていて、保存されたモデルから読み込まれていないようです。

同じPython実行で保存と読み込みを行うと、プロセスが終了してテストスクリプトが実行されずに読み込まれることに注意してください。私の問題は、私はモードを訓練するときにこれが動作しない - >保存する - >プロセスが終了する - >テストフラグでスクリプトを再実行する - >モデルはエラーなしで読み込まれますが、

コード:

ラン#1

# creating LSTM model... 

with tf.Session() as sess: 
    saver = tf.train.Saver() 

    # training... 

    save_path = saver.save(sess, "./saved_models/model.ckpt") 
    print("Model saved in file: %s" % save_path) 

ラン#2

# creating the same exact LSTM model... 

with tf.Session() as sess: 
    saver = tf.train.Saver() 

    saver.restore(sess, "./saved_models/model.ckpt") 
    print("Model restored.") 

    # testing... 

私は背中合わせにこれら二つのスニペットを実行した場合、私は所望の出力を得る - モデルが訓練されています簡単なシーケンスを予測し、テスト中に正しく予測します。 2つのスニペットを別々に実行すると、モデルはテスト中に間違ったシーケンスを予測します。

更新:MetaGraphをインポートしようと提案されましたが、どちらも機能しません。コード:

ラン#1

# creating model... 

tf.add_to_collection('a', net.a) 
# adding nodes ... 
tf.add_to_collection('z', net.z) 

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    # training... 
    save_path = saver.save(sess, "./saved_models/my-model") 
    print("Model saved in file: %s" % save_path) 

ラン#2

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('./saved_models/my-model.meta') 
    new_saver.restore(sess, './saved_models/my-model') 

    net.a = tf.get_collection('a')[0] 
    # adding nodes ... 
    net.z = tf.get_collection('z')[0] 

    # testing... 

は、上記のコードは正常に動作します - しかし、TESTSET結果は、私が実行した場合、それは、再びポストトレーニング(とないであることを示しています同じPythonインスタンス内の2つのスニペットは正しく動作します)。

これはかなり些細なことであり、私はそれを動作させることができません。どんな助けも歓迎です。具体的には、グラフ全体を保存する必要はありません。変数(LSTMセル内の変数の一部)だけです。

+0

メタグラフ(tf.train.import_meta_graph)をインポートしてみてください。 https://www.tensorflow.org/programmers_guide/meta_graph –

+0

これは私にとってもうまくいかなかった。質問に関連するコードを追加しました... –

答えて

1

同じ問題が発生しましたが、tf.Variable()、そうでしょうか? tf.get_variable()に変更してください。それは私のために働いた:)

+0

私はget_variableを使用しています –

関連する問題