2016-12-30 9 views
3

TensorFlowのワンホット値から単一の整数にデコードまたは変換する方法の例は見つかりませんでした。Tensorflowでワンホットラベルをどのようにデコードしますか?

私はtf.one_hotを使用していましたが、私のモデルを訓練することができましたが、分類後にラベルを理解する方法についてちょっと混乱しています。私のデータは、私が作成したTFRecordsファイルを介してフィードされています。私はファイルにテキストラベルを格納することを考えましたが、動作させることができませんでした。あたかもTFRecordsがテキスト文字列を保存できないかのように見えたのですが、誤解されたかもしれません。

+0

ちょっと@Matt、答えはあなたの問題を解決しますか? – martianwars

+0

はい、感謝しました! –

答えて

9

tf.argmaxを使用して、マトリックス内の最大の要素のインデックスを見つけることができます。あなたの1つのホットベクトルは1次元であり、1と他の0の1つしか持たないので、これはあなたが単一のベクトルを扱っていると仮定して動作します。 batch_size * num_classesのより標準的な行列については

index = tf.argmax(one_hot_vector, axis=0) 

、サイズbatch_size * 1の結果を得るためにaxis=1を使用しています。

6

通常、1ホット・エンコーディングは、batch_size行とnum_classes列の行列であり、各行はすべてゼロであり、選択されたクラスに対応する非ゼロが1つのみであるため、tf.argmax()を使用して整数のベクトルラベル:

BATCH_SIZE = 3 
NUM_CLASSES = 4 
one_hot_encoded = tf.constant([[0, 1, 0, 0], 
           [1, 0, 0, 0], 
           [0, 0, 0, 1]]) 

# Compute the argmax across the columns. 
decoded = tf.argmax(one_hot_encoded, axis=1) 

# ... 
print sess.run(decoded) # ==> array([1, 0, 3]) 
+0

OPは1つの値から*単一の整数*を求めているので、OPは単なるベクトルを使用しているようです – martianwars

関連する問題