2017-07-12 18 views
0

tf.one_hotのインデックスはどのような形式ですか?テンソルにすることはできますか?私は今、このtensorflow one_hot入力サイズ

sess.run(prediction, feed_dict={x:batch_x, y:batch_y}) 
# [1 1 1 1 0 0 0 0 0] an array of either zero, one 

を実行する場合は、私が実行して、

# [ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ] 

しかし、これは2次元の配列になりたい、

prediction = tf.argmax(output, axis=1) 
pred_hot = tf.one_hot(indices = predictions, depth=2) 

私は、次のコードを持っています

sess.run(pred_hot, feed_dict={x:batch_x, y:batch_y}) 

エラーを示します。

私はtf.one_hot関数にどのような形式を入力すればよいのでしょうか?なぜこのように動作しないのですか?あなたは2次元のアレイとして供給された入力を供給する必要がないので、

答えて

0

tf.one_hot()は、[1 1 1 1 0 0 0 0 0][ [0,1], [0,1], [0,1], [0,1], [1,0], [1,0], [1,0], [1,0], [1,0], [1,0] ]に変換します。一般的な例として

、入力[1 0 2]であり、深さが3次にtf.one_hot()[ [0 1 0] [ 1 0 0] [0 0 1] ]

に変換されていると
0

深さ= 2を使用してみてください、すなわち

tf.one_hot(tensor, depth=2) 
関連する問題