2017-08-28 12 views
1

私は、次のtensorflowのチュートリアルからmnist畳み込みネットワークを適応されています: https://www.tensorflow.org/tutorials/layersテンソルフローでは、ジェネレータから予測をどのように読み取ることができますか?

私は入力の形状を除いて同じコードを使用します。私が訓練と評価をするとき、私は良い結果を得ます。しかし、私は予期せぬことが分かっているので、何が分からないのか分かりました。

y=SN_classifier.predict(input_fn=my_data_to_predict) 

を実行しているときmy_data_to_predictが右形状のnumpyの配列ですしかし、私は次のような出力が得られます。

私はそれを私のやって読むことができるはずフォーラムで読んだ
<generator object Estimator.predict at 0x7fb1ecefeaf0> 

: yのiに対する :それは 'numpy.ndarray' オブジェクトを提起

しかし プリント(i)は呼び出すことはできません

私は他のフォーラムで読ん
print('Predictions: {}'.format(list(y)) 

..

あなたはそれが私の予測出力しない理由のいずれかのアイデアを持っているだろう:私がしようとした場合と同じはどうなりますか?ここで

は私が予測する定義するコードの一部です:

predictions = { 
     # Generate predictions (for PREDICT and EVAL mode) 
     "classes": tf.argmax(input=logits, axis=1), 
     # Add `softmax_tensor` to the graph. It is used for PREDICT and by the 
     # `logging_hook`. 
     "probabilities": tf.nn.softmax(logits, name="softmax_tensor") 
    } 
    if mode == tf.estimator.ModeKeys.PREDICT: 
     return(tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)) 

そして、私はそれを呼び出す:

y=SN_classifier.predict(input_fn=my_data_to_predict) 

はあなたの助けのためにありがとうございました、私はいずれかを取りますよアドバイス、アイデア:)

答えて

1

input_fnはテンソルを生成する関数でなければなりません。 numpy_input_fnにラップしてください。

input_fn = tf.estimator.inputs.numpy_input_fn(my_data_to_predict) 
for single_prediction in SN_classifier.predict(input_fn): 
    predicted_class = single_prediction['class'] 
    probability = single_prediction['probability'] 
    do_something_with(predicted_class, probability) 
関連する問題