2017-04-06 9 views
11

私はTensorFlowでモデルトレーニングを再開しようとしています。私はグラフを再構成しないように、最近追加された(0.12+と思う)import_meta_graph()を使用したいと思います。Python TensorFlow:オプティマイザとimport_meta_graphを使ってトレーニングを再開するには?

私はこのためのソリューションを見てきました。 Tensorflow: How to save/restore a model?が、私はAdamOptimizerの問題にぶつかります。具体的には、ValueError: cannot add op with name <my weights variable name>/Adam as that name is already usedエラーが発生します。 This can be fixed by initializingですが、私のモデル値はクリアされています!

そこには他の回答といくつかの完全な例がありますが、それらは常に古いと思われ、新しいimport_meta_graph()アプローチを含まないか、非テンソルオプティマイザを含んでいません。私が見つけることができる最も近い質問はtensorflow: saving and restoring sessionですが、最終的な明確な解決策はなく、例はかなり複雑です。

理想的には、実行可能な単純な例を最初から開始し、停止してからもう一度ピックアップしたいと思います。私は(下)で動作するものを持っていますが、私は何かが欠けているかどうかも疑問に思っています。確かに私はこれを行う唯一の人ではないのですか?

+0

私はAdamOptimizerで同じ問題がありました。私は、私の活動をコレクションに入れて仕事をすることができました。この例は私を助けました:http://www.seaandsailor.com/tensorflow-checkpointing.html –

答えて

4

ここで私は、ドキュメント、他の似たようなソリューション、試行錯誤を読んでから思いついたものです。これは、ランダムなデータの簡単な自動エンコーダーです。走ってから再び走った場合、それは中断していたところから続きます(つまり、最初の走行でのコスト関数は〜0.5から0.3秒の実行開始〜0.3になります)。私が何かを見逃していない限り、すべての保存、コンストラクタ、モデル構築、add_to_collectionが必要であり、正確な順序であるが、より単純な方法があるかもしれない。 グラフをロードするのは実際には上に書かれているので実際には必要ありませんが、私の実際のアプリケーションでは必要なものです。

from __future__ import print_function 
import tensorflow as tf 
import os 
import math 
import numpy as np 

output_dir = "/root/Data/temp" 
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt") 

input_length = 10 
encoded_length = 3 
learning_rate = 0.001 
n_epochs = 10 
n_batches = 10 
if not os.path.exists(model_checkpoint_file_base + ".meta"): 
    print("Making new") 
    brand_new = True 

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in") 
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length], 
              -1.0/math.sqrt(input_length), 
              1.0/math.sqrt(input_length)), name="W_enc") 
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc") 
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded") 
    W_dec = tf.transpose(W_enc, name="W_dec") 
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec") 
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded") 
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost") 

    saver = tf.train.Saver() 
else: 
    print("Reloading existing") 
    brand_new = False 
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta") 
    g = tf.get_default_graph() 
    x_in = g.get_tensor_by_name("x_in:0") 
    cost = g.get_tensor_by_name("cost:0") 


sess = tf.Session() 
if brand_new: 
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost) 
    init = tf.global_variables_initializer() 
    sess.run(init) 
    tf.add_to_collection("optimizer", optimizer) 
else: 
    saver.restore(sess, model_checkpoint_file_base) 
    optimizer = tf.get_collection("optimizer")[0] 

for epoch_i in range(n_epochs): 
    for batch in range(n_batches): 
     batch = np.random.rand(50, input_length) 
     _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch}) 
     print("batch_cost:", curr_cost) 
     save_path = tf.train.Saver().save(sess, model_checkpoint_file_base) 
2

私は同じ問題を持っていたと私はちょうど、少なくとも私のコードでは、間違っていたものを考え出しました。

最後に、saver.restore()に間違ったファイル名を使用しました。この関数は、ちょうどsaver.save()関数のように、ファイル拡張子なしでファイル名を指定する必要があります。これにより

saver.restore(sess, 'model-1') 

代わりの

saver.restore(sess, 'model-1.data-00000-of-00001') 

私はあなたが何をしたい正確に何をすべきか:ゼロから始めます、停止して、再びピックアップします。 tf.train.import_meta_graph()関数を使用してメタファイルからセカンダリセーバーを初期化する必要はなく、オプティマイザの初期化後に明示的にtf.initialize_all_variables()を指定する必要はありません。

私の完全なモデルは、このようなルックスを復元:

with tf.Session() as sess: 
    saver = tf.train.Saver() 
    sess.run(tf.global_variables_initializer()) 
    saver.restore(sess, model-1) 

私は、プロトコルのV1にあなたはまだファイル名に.ckptを追加する必要がありました、そしてimport_meta_graph()のためにあなたはまだ可能性がありますこれは、.metaを追加する必要があると思いますユーザー間の混乱。たぶん、これはドキュメントでより明示的に指摘されるべきです。

0

復元セッションで保護オブジェクトを作成するときに問題が発生することがあります。

復旧セッションで以下のコードを使用した場合と同じエラーが発生しました。エラーが解消された

saver = tf.train.Saver() 
saver.restore(sess, "tmp/hsmodel") 

saver = tf.train.import_meta_graph('tmp/hsmodel.meta') 
saver.restore(sess, tf.train.latest_checkpoint('tmp/')) 

しかし、私はこのように変化したときに、。 "tmp/hsmodel"は、セービングセッションでsaver.save(sess、 "tmp/hsmodel")に渡すパスです。

トレーニングMNISTネットワーク(Adamオプティマイザを含む)のセッションの保存と復元に関する簡単な例がここにあります。これは私のコードと比較して問題を解決するのに役立ちました。

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/4_Utils/save_restore_model.py

関連する問題