2017-03-22 18 views
2

分散型テンソルフローモデルを設定したいが、MonitoredTrainingSession & StopAtStepHookがどのようにやりとりするかを理解できない。 私は、このセットアップ持っていた前に:基本StopAtStepHook&MonitoredTrainingSessionの使用

for epoch in range(training_epochs): 
    for i in range(total_batch-1): 
    c, p, s = sess.run([cost, prediction, summary_op], feed_dict={x: batch_x, y: batch_y}) 

今、私は、このセットアップ(簡体字)している:

def run_nn_model(learning_rate, log_param, optimizer, batch_size, layer_config): 
    with tf.device(tf.train.replica_device_setter(
     worker_device="/job:worker/task:%d" % mytaskid, 
     cluster=cluster)): 

    # [variables...] 

    hooks=[tf.train.StopAtStepHook(last_step=100)] 
    if myjob == "ps": 
     server.join() 
    elif myjob == "worker": 
     with tf.train.MonitoredTrainingSession(master = server.target, 
           is_chief=(mytaskid==0), 
           checkpoint_dir='/tmp/train_logs', 
           hooks=hooks 
           ) as sess: 

      while not sess.should_stop(): 
      #for epoch in range...[see above] 

を、これは間違ってますか?それはスローされます:

RuntimeError: Run called even after should_stop requested. 
Command exited with non-zero status 1 

テンソルフローはここでどのように調整されていますか?ステップカウンターを使用してトレーニングを追跡するにはどうすればよいですか? (私がこの便利な時代変数を使う前に)

答えて

1

sess.runが実行されるたびに、カウンタがインクリメントされます。ここでの問題は、フック(200)で指定されたステップ数より多くのステップ(total_batch-1 x training_epochs)を実行していることです。

私はそれがきれいな構文だとは思わないが、あなたができることは、定義last_step = total_batch-1 x training_epochsです。

関連する問題