2017-01-10 27 views
0

tf.contrib.learn.ReadBatchFeatureshttps://www.tensorflow.org/versions/master/api_docs/python/contrib.learn/input_processing#read_batch_features)を使用して、Tensorオブジェクトのdictを返す入力関数の一部としてExample protosを読み込みます。私のモデルを訓練した後、私のEstimatorpredictと呼ぶと、予測としてのバッチが配列として返されます。これは既知の値と比較したいと思います。TensorFlow:配列にバッチフィーチャーを読み込む

tf.Session().run(labels)labelsは既知の値のTensor)を呼び出して既知の値を取得しようとします。入力関数から返されます。しかし、この時点で私のプログラムはハングします。私は、ディスクからラベルを読み取る無限ループで止まっていると思っています。ちょうど1つのバッチを読みたいのではないでしょうか。

これは、labelsTensorの1つのバッチ値を取得する正しい方法ですか?

編集:私はキューのランナーを開始しようとしましたが、以下は正しいですか?

_, labels = eval_input_fn() 
with tf.Session().as_default(): 
    tf.local_variables_initializer() 
    tf.train.start_queue_runners() 
    label_values = labels.eval() 
print(label_values) 
+1

多分あなたは、キューランナーを起動しなかったし、あなたのキューが空になりますか?空のキューからの読み取りがハングアップする –

+0

ありがとうございます。ドキュメントにはこれが記載されていることを覚えていますが、忘れました。私のセッションはまだハングアップしていますが、何が間違っているのか教えてください。 – Dimpl

+0

タイムアウトを追加すると、 'config = tf.ConfigProto(); config.operation_timeout_in_ms = 60000; sess = tf.InteractiveSession(config = config) ' –

答えて

2

あなたが必要とする全体のセットアップは次のとおりです。

_, labels = eval_input_fn() 

with tf.Session() as sess: 
     sess.run([ 
      tf.local_variables_initializer(), 
      tf.global_variables_initializer() 
     ]) 

     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(sess=sess, coord=coord) 

     try: 
      while not coord.should_stop(): 
       print(sess.run(label)) 

     except tf.errors.OutOfRangeError as error: 
      coord.request_stop(error) 
     finally: 
      coord.request_stop() 
      coord.join(threads) 
関連する問題