2017-11-14 11 views
2

with tf.device("/gpu:0")と併用することになっていたと思いますが、どこに入れますか? Tensorflow - tf.Estimator()のCPUの代わりにGPUを使用する方法CNN

with tf.device("/gpu:0"): 
    tf.app.run() 

は、だから私はそれがtf.appmain()機能、または私は推定に使用するモデル関数に入れなければならない

:私はそれがだとは思いませんか?

EDIT:あなたが見ることができるように、私はどこでも、ここでセッションの明示的な宣言を持っていないので、場所を正確に私はwith tf.device("/gpu:0")を入れてください

def main(unused_argv): 
    """Code to load training folds data pickle or generate one if not present""" 

    # Create the Estimator 
    mnist_classifier = tf.estimator.Estimator(
     model_fn=cnn_model_fn2, model_dir="F:/python_machine_learning_codes/tmp/custom_age_adience_1") 

    # Set up logging for predictions 
    # Log the values in the "Softmax" tensor with label "probabilities" 
    tensors_to_log = {"probabilities": "softmax_tensor"} 
    logging_hook = tf.train.LoggingTensorHook(
     tensors=tensors_to_log, every_n_iter=100) 

    # Train the model 
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
     x={"x": train_data}, 
     y=train_labels, 
     batch_size=64, 
     num_epochs=None, 
     shuffle=True) 
    mnist_classifier.train(
     input_fn=train_input_fn, 
     steps=500, 
     hooks=[logging_hook]) 

    # Evaluate the model and print results 
    """Code to load eval fold data pickle or generate one if not present""" 

    eval_logs = {"probabilities": "softmax_tensor"} 
    eval_hook = tf.train.LoggingTensorHook(
     tensors=eval_logs, every_n_iter=100) 
    eval_input_fn = tf.estimator.inputs.numpy_input_fn(
     x={"x": eval_data}, 
     y=eval_labels, 
     num_epochs=1, 
     shuffle=False) 
    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn, hooks=[eval_hook]) 

:このことができます場合は、これは私のmain()機能ですか?推定で

答えて

1

は残念ながらtensorflowウェブでの完全なドキュメントがありません...

sess.run() 

よう

sess = tf.Session(config = xxxxxxxxxxxxx) 

どちらのステートメントのような任意の文ではありません。私はのrunConfig

# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which 
# trains faster than GPU for this model. 
run_config = tf.estimator.RunConfig().replace(
     session_config=tf.ConfigProto(log_device_placement=True, 
             device_count={'GPU': 0})) 

の異なるオプションをしようとしている は、私はいくつかの進歩を取得する場合、私はそれをここに掲載しますので、実際に私はあなたの仕事のようなもので働いている...これで動作するようにしてください。

こちらをご覧ください: https://github.com/tensorflow/models/blob/master/official/wide_deep/wide_deep.py この例では、彼らはコードを使用しているモデルは、CPU上で実行されていることを確認するために.replace文で上記示しました。

関連する問題