2017-02-22 2 views
1

TF Dev Summit talkを見ると、TensorFlowグラフのXLAをオンにする可能性があり、非常に実験的なメリットがあるので、少し実験したいと思っていました。tf.contrib.learn.Estimatorを使用するときXLAを有効にする

質問tf.contrib.learn.Estimatorを使用する場合、JIT XLAを有効にするにはどうすればよいですか?

Iは

with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"): 
    output = tf.add(input1, input2) 

によるJIT XLAフラグ特定のOPSしかしながら、ドキュメントは、これは単にテストのために意図されていることを警告することができます。私が推奨される方法

# Config to turn on JIT compilation 
config = tf.ConfigProto() 
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 

sess = tf.Session(config=config) 

を使用してこれを実行できるようにしたいと思いますが、私はmodel_fn外からtf.Sessionを設定する方法を把握することはできません。

この擬似コードは、問題をより明確にする:

def char_cnn_model(features, target, mode, params, model_dir): 
    # do stuff to build the CNN 

    ... 

    return tf.contrib.learn.ModelFnOps(mode=mode, 
     predictions=predictions_dict, 
     loss=loss, 
     train_op=train_op, 
     eval_metric_ops=eval_metric_ops) 

def main(unused_argv): 
    # load training data, test data etc 

    ... 

    classifier = learn.Estimator(
     model_fn=char_cnn_model, # defined above 
     model_dir=model_dir, 
     params=params, 
     config=tf.contrib.learn.RunConfig(save_checkpoints_secs=60, 
              log_device_placement=True, 
              tf_random_seed=7)) 
    classifier.fit(
     X_train, 
     y_train, 
     steps=5000, 
     monitors=[validation_monitor]) # validation_monitor defined in main 

    accuracy_score = classifier.evaluate(x=X_test, y=y_test)          

tf.contrib.learn.RunConfigは良い候補のように思えたが、それは、私は理にかなって推測され(セッションのために何かを持っていない、なぜのrunConfigを持っているでしょう既存のセッションへのポインタ?)

私はこれをやりすぎる可能性があります。tf.get_default_sessionは私が必要とするものですが、作成後にセッションの設定を変更することはできますか?

答えて

2

まだ見えません。 MonitoredSession is being instantiated directlyとオプションはローカルに渡されます。 列車メソッドをサブクラス化し、XLA構成を使用する監視対象セッションでオーバーライドすることのみが可能です。

関連する問題