2017-12-02 11 views
0

私はtf.train.Saver()がどのように動作するのか少し混乱します。私はfollowing codeだけでトレーニング可能な変数を保存する必要があります。tf.train.Saverはどのように正しく動作しますか?

import tensorflow as tf 
# Create some variables. 
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer) 
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer) 

inc_v1 = v1.assign(v1+1) 
dec_v2 = v2.assign(v2-1) 

# Add an op to initialize the variables. 
init_op = tf.global_variables_initializer() 

# Add ops to save and restore all the variables. 
saver = tf.train.Saver(tf.trainable_variables()) 
print([x.name for x in tf.trainable_variables()]) 
# Later, launch the model, initialize the variables, do some work, and save the 
# variables to disk. 
with tf.Session() as sess: 
    sess.run(init_op) 
    # Do some work with the model. 
    inc_v1.op.run() 
    dec_v2.op.run() 
    # Save the variables to disk. 
    save_path = saver.save(sess, "./model.ckpt") 
    print("Model saved in file: %s" % save_path) 

、次のコードは、ちょうどそれらを見て:

import tensorflow as tf 
sess = tf.Session() 
saver = tf.train.import_meta_graph('model.ckpt.meta') 
saver.restore(sess,'model.ckpt') 
print([v.name for v in tf.get_default_graph().as_graph_def().node]) 

最初のコード出力['v1:0', 'v2:0']、予想通り。私は、同じ結果を生成するために第2のコードを期待していますが、私はこれを参照してください。

['v1/Initializer/zeros', 'v1', 'v1/Assign', 'v1/read', 'v2/Initializer/zeros', 'v2', 'v2/Assign', 'v2/read', 'add/y', 'add', 'Assign', 'sub/y', 'sub', 'Assign_1', 'init', 'save/Const', 'save/SaveV2/tensor_names', 'save/SaveV2/shape_and_slices', 'save/SaveV2', 'save/control_dependency', 'save/RestoreV2/tensor_names', 'save/RestoreV2/shape_and_slices', 'save/RestoreV2', 'save/Assign', 'save/RestoreV2_1/tensor_names', 'save/RestoreV2_1/shape_and_slices', 'save/RestoreV2_1', 'save/Assign_1', 'save/restore_all']

tfのではなく、具体的に述べた2のすべての変数を保存し、なぜ私はわかりません。どうやってやるの?

答えて

0

私はこのことができます願っていtensorflow wiki

tf.reset_default_graph() 

# Create some variables. 
v1 = tf.get_variable("v1", shape=[3], name="v1") 
v2 = tf.get_variable("v2", shape=[5], name="v2") 
saver = tf.train.Saver(var_list=[v1, v2]) # list of TF variables that are to be restored 

with tf.Session() as sess: 
    # Restore variables from disk. 
    saver.restore(sess, "./model.ckpt") 
    print("Model restored.") 
    # Check the values of the variables 
    print("v1 : %s" % v1.eval()) 
    print("v2 : %s" % v2.eval()) 

から次のコードを試してみてください!

+0

私はおそらくこれについてはあまり明確ではありませんでしたが、私の目標はそれらのすべてではなくいくつかの変数を保存して読み込むことです。 – src

+0

@src答えを編集しました –

関連する問題