2017-05-07 8 views
0

2つのpythonファイルがあります。最初はテンソルフロー モデルを保存するためのファイルです。 2番目は、保存されたモデルを復元するためのものです。保存されたテンソルフローモデルを復元するには?

質問:私は2つのファイルを次々に実行すると、それは大丈夫だ

  1. 私が最初のものを実行すると、編集を再開して2番目のものを実行します。 は、w1が定義されていないことを教えてくれますか?

私は何をしたいです:

  1. 保存tensorflowモデル

  2. が保存されたモデル、それと間違って何

を復元しますか?あなたの親切に助けてくれてありがとう?

model_save.py

import tensorflow as tf 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1') 
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
saver = tf.train.Saver() 

with tf.Session() as sess: 
sess.run(tf.global_variables_initializer()) 
saver.save(sess, 'SR\\my-model') 

model_restore.py

import tensorflow as tf 

with tf.Session() as sess:  
saver = tf.train.import_meta_graph('SR\\my-model.meta') 
saver.restore(sess,'SR\\my-model') 
print (sess.run(w1)) 

enter image description here

答えて

1

簡単に言うと、あなたが使用する必要があります

print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0'))) 

model_restore.pyファイルにprint (sess.run(w1))の代わりに入力してください。

model_save.py

import tensorflow as tf 
w1_node = tf.Variable(tf.random_normal(shape=[2]), name='w1') 
w2_node = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
saver = tf.train.Saver() 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(w1_node.eval()) # [ 0.43350926 1.02784836] 
    #print(w1.eval()) # NameError: name 'w1' is not defined 
    saver.save(sess, 'my-model') 

w1_nodeだけmodel_save.pyで定義され、model_restore.pyファイルがそれを認識することはできません。 変数をnameと呼ぶときは、get_tensor_by_nameを使用する必要があります。この投稿はTensorflow: How to get a tensor by name?です。ご親切に答えを

model_restore.py

import tensorflow as tf 

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('my-model.meta') 
    saver.restore(sess,'my-model') 
    print (sess.run(tf.get_default_graph().get_tensor_by_name('w1:0'))) 
    # [ 0.43350926 1.02784836] 
    print(tf.global_variables()) # print tensor variables 
    # [<tf.Variable 'w1:0' shape=(2,) dtype=float32_ref>, 
    # <tf.Variable 'w2:0' shape=(5,) dtype=float32_ref>] 
    for op in tf.get_default_graph().get_operations(): 
    print str(op.name) # print all the operation nodes' name 
+0

感謝。 あなたの提案と関連する参考文献によると、 私はちょうど問題を解決しました:) –

関連する問題