2016-12-30 16 views
1

次のコードを考える:Tensorflow:セッションを保存し、resoring - 複数の変数

import tensorflow as tf 

with tf.Session() as sess: 
    var = tf.Variable(42, name='var') 
    sess.run(tf.global_variables_initializer()) 
    tf.train.export_meta_graph('file.meta') 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('file.meta') 
    print sess.run(var) 

を私はValueError: At least two variables have the same name: varを言っラインsaver = tf.train.import_meta_graph('file.meta')でエラーが発生します。

どうすればこの問題を回避できますか?とにかくは、のメタグラフをインポートするときに計算グラフを上書きしますか?

EDIT:

私は、次のコードに到着した:

import tensorflow as tf 

file_name = "./file" 

with tf.Session() as sess: 
    var = tf.Variable(42, name='my_var') 
    sess.run(tf.global_variables_initializer()) 

    saver = tf.train.Saver() 
    saver.save(sess,file_name) 
    saver.export_meta_graph(file_name + '.meta') 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph(file_name + '.meta') 
    saver.restore(sess, file_name) 
    print(sess.run(var)) 

    # new code that fails: 
    saver = tf.train.Saver() 
    saver.save(sess,file_name) 
    saver.export_meta_graph(file_name + '.meta') 

これはvarに正しい値を出力しますが、私はグラフを二時間を節約するために来るとき、私が取得同じ元のエラー:ValueError: At least two variables have the same name: var

+0

ちょっと@トム、あなたは私の答えに満足していますか? – martianwars

+0

Hey @martianwars、私の編集を参照してください – Tom

+0

Hey @martianwarsは 'reset_default_graph'を含むように編集しました – Tom

答えて

2

この場合、既に変数が定義されている既定のグラフに変数を読み込みます。したがって、 をインポートする前にTensorFlowグラフをリセットする必要があります。

tf.reset_default_graph()を使用してください。インポートする前にExporting and Importing a MetaGraphの下の「デフォルトグラフ内のインポート」セクションをご覧ください。

もちろん、変数vartf.get_variable()を使って再定義する必要があります。このコードを試してみてください、あなたの中間コードが働いていなかった

import tensorflow as tf 

with tf.Session() as sess: 
    var = tf.Variable(42, name='var') 
    sess.run(tf.global_variables_initializer()) 
    tf.train.export_meta_graph('file.meta') 
tf.reset_default_graph() 
with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('file.meta') 
    var = tf.global_variables()[0] 
    sess.run(tf.initialize_all_variables()) 
    print sess.run(var) 

理由はtf.get_variable()がランダムに初期化された新しい変数を作成していたということでした。最初にtf.get_variable_scope().reuse_variables()を必ず実行してください。 Understanding tf.get_variable()をご覧ください。

tf.Variable()を使って作成した変数は、tf.get_variable()で直接再利用することはできません。理由を正確に知るには、commentcommentをご覧ください。したがって、今後再利用したい場合は、tf.get_variable()を作成する必要があります。

関連する問題