2017-10-23 7 views
1

デフォルトでは、TensorFlow分散トレーニングでは、非同期分散トレーニングでは、個々のワーカーとパラメータサーバー間の通信だけが必要ですが、ワーカーとパラメータサーバー間のすべての接続が確立されます。tf.contrib.learn.Experimentでdevice_filtersを使用するにはどうすればよいですか?

tf.contrib.learn.Experimentを使用しているときの通信を制限するにはどうすればよいですか?

答えて

1
# The easiest way to parse TF_CONFIG environment variable is to create a RunConfig. 
# Unfortunately, it is an immutable object, so we're going to create a 
# temporary one and only use it for `task_type` and `task_id`. 
tmp = tf.contrib.learn.RunConfig() 
task_type, task_id = tmp.task_type, tmp.task_id 

# We use a device_filter to limit the communication between this job 
# and the parameter servers, i.e., there is no need to directly 
# communicate with the other workers; attempting to do so can result 
# in reliability problems. 
device_filters = [ 
    '/job:ps', '/job:%s/task:%d' % (task_type, task_id) 
] 
session_config = tf.ConfigProto(device_filters=device_filters) 
run_config = tf.contrib.learn.RunConfig(
    model_dir=args.job_dir, 
    session_config=session_config) 

# Create the experiment_fn: 
experiment_fn = ... 

# Run the experiment 
learn_runner.run(experiment_fn, run_config=run_config) 
関連する問題