2017-04-04 18 views
2

保存したチェックポイントからglobal_stepを取得できないようです。私のコード:Tensorflowはチェックポイントからglobal_stepを復元できません

//(...) 
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir) 
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file), clear_devices=True) 
saver.restore(sess, checkpoint_file) 
for v in tf.global_variables(): 
    print(v) 
test = tf.get_variable("global_step") 
print(test) 

結果:

//(...) 
Tensor("global_step/read:0", shape=(), dtype=int32) 
//(...) 
Traceback (most recent call last): 
    File "train.py", line XXX, in <module> 
    test = tf.get_variable("global_step") 
    File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 988, in get_variable 
    custom_getter=custom_getter) 
    File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 890, in get_variable 
    custom_getter=custom_getter) 
    File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 348, in get_variable 
    validate_shape=validate_shape) 
    File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 333, in _true_getter 
    caching_device=caching_device, validate_shape=validate_shape) 
    File "(...)/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 660, in _get_single_variable 
    "but instead was +1ms." % (name, shape)) 
ValueError: Shape of a new variable (global_step) must be fully defined, but instead was <unknown>. 

私もglobal_step:0と​​が、同じ結果と試みました。任意のヒント?またはtf.get_variableを使用しないでください。

は、その変数が最初の場所でtf.get_variableで作成された場合にのみ、既存の変数を取得するためにtf.get_variableを使用することができますあなたの

答えて

1

ありがとうございます。また、可変スコープを適切に設定する必要があります。ここには'global_step'という新しい変数を作成しようとしているようですが、まだ存在しないことを示しています。 Hereは、tf.get_variableの使用方法の詳細です。あなたができる

global_step = tf.get_default_graph().get_tensor_by_name('global_step:0') 
0

# to create 
global_step = tf.Variable(tf.constant(0), trainable=False, name='global_step') 
tf.add_to_collection('global_step', global_step) 

# to load 
global_step = tf.get_collection_ref('global_step')[0] 
# get the current value 
gs = sess.run(global_step) 

編集:あなたはグローバル段差を保存する方法を変更できない場合は、次のように動作するはずです、私は通常、このようなグローバル段差を扱う

このようにする:

with tf.Session() as sess: 
    predict_top_5 = tf.nn.top_k(scores, k=5) 
    label_top_5 = tf.nn.top_k(input_y, k=5) 
    ckpt = tf.train.get_checkpoint_state('models') 
    if ckpt and ckpt.model_checkpoint_path: 
     saver.restore(sess,ckpt.model_checkpoint_path) 
     global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) 
関連する問題