2017-01-23 9 views
0

分散テンソルフローを使用して、パラメーターサーバーを使用してストリーミング形式でモデルをトレーニングしたいと考えています。私はGraph is finalized and cannot be modifiedが第二のパスで入力演算子を追加することをランタイムエラーを取得し、この設定で分散演算が開始された後、入力演算子をテンソルフローグラフに追加するにはどうすればよいですか?

def train_model(filenames, params): 
    with tf.device(tf.train.replica_device_setter(
      worker_device='/job:worker/task:%d' % params.task_index, cluster=cluster)): 
     input_op = construct_input_op(filenames) 
     global_step = tf.Variable(0) 
     train_op = construct_train_op(input_op, global_step, params) 

    init_op = tf.global_variables_initializer() 
    saver = tf.train.Saver(tf.global_variables() + tf.local_variables()) 

    supervisor = tf.train.Supervisor(
     is_chief=params.task_index == 0, 
     logdir=params.training_summary_dir, 
     init_op=init_op, 
     saver=saver, 
     global_step=global_step, 
     save_model_secs=0) 

    with supervisor.managed_session(server.target) as sess: 
     while not supervisor.should_stop() and step <= params.max_steps: 
      sess.run(train_op) 

    supervisor.stop() 

cluster = tf.train.ClusterSpec({"ps": [params.param_server_host], "worker": params.worker_hosts}) 
server = tf.train.Server(cluster, job_name="worker", task_index=params.task_index) 
while True: 
    filenames = wait_for_new_training_data(...) 
    train_model(filenames, params) 

:労働者のセットアップはhttps://www.tensorflow.org/how_tos/distributed/に基づいて、このようなものです。このサンプルをどのように動作させることができますか?

+0

グラフオブジェクトに対して '._unsafe_unfinalize()'を試すことはできますが、必要な操作をすべて前もって作成する方が堅牢な解決策になります(実行中にグラフを修正することによって生じるパフォーマンス/ /同時に) –

+0

@ YaroslavBulatov - これをあなたの答えとすることができますか?ありがとう! – dga

答えて

0

グラフオブジェクトで._unsafe_unfinalize()を試すことはできますが、必要な操作をすべて前もって作成する方がより堅牢なソリューションになります(グラフの実行中にグラフを変更することによるパフォーマンス/スレッドセーフの問題があります) )

関連する問題