TensorFlowを初めて使用していて、2列、1列のインデックス、2列のラベル文字列を含むコンマ区切り値(csv)ファイルを読んでみたいと思います。私はcsvファイルの行を1行ずつ読み込む次のコードを持っており、print文を使ってcsvファイルのデータを正しく取得できます。しかし、私は文字列ラベルからワンホットエンコーディング変換を行い、TensorFlowでそれを行う方法はしたくありません。最後の目標は、tf.train.batch()関数を使用して、ニューラルネットワークを訓練するための1つのホットラベルベクトルのバッチを得ることです。TensorFlowで文字列ラベルをワンホットベクトルに変換する方法は?
以下のコードでわかるように、TensorFlowセッション内でラベル項目ごとに1ホットベクトルを手動で作成できます。しかし、どのようにtf.train.batch()関数を使うのですか?行を移動すると、TensorFlowセッションブロック(col2をlabel_one_hotに置き換える)に行を移動すると、プログラムは何もしないでブロックします。私はTensorFlowセッションの外でワンホットベクトル変換を移動しようとしましたが、正しく動作するようにはできませんでした。それを行う正しい方法は何ですか?助けてください。
label_files = []
label_files.append(LABEL_FILE)
print "label_files: ", label_files
filename_queue = tf.train.string_input_producer(label_files)
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
print "key:", key, ", value:", value
record_defaults = [['default_id'], ['default_label']]
col1, col2 = tf.decode_csv(value, record_defaults=record_defaults)
num_lines = sum(1 for line in open(LABEL_FILE))
label_batch = tf.train.batch([col2], batch_size=5)
with tf.Session() as sess:
coordinator = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coordinator)
for i in range(100):
column1, column2 = sess.run([col1, col2])
index = 0
if column2 == 'airplane':
index = 0
elif column2 == 'automobile':
index = 1
elif column2 == 'bird':
index = 2
elif column2 == 'cat':
index = 3
elif column2 == 'deer':
index = 4
elif column2 == 'dog':
index = 5
elif column2 == 'frog':
index = 6
elif column2 == 'horse':
index = 7
elif column2 == 'ship':
index = 8
elif column2 == 'truck':
index = 9
label_one_hot = tf.one_hot([index], 10) # depth=10 for 10 categories
print "column1:", column1, ", column2:", column2
# print "onehot label:", sess.run([label_one_hot])
print sess.run(label_batch)
coordinator.request_stop()
coordinator.join(threads)
あなたは順番に
tf.one_hot
経由でワンホットベクトルに変換します、プレースホルダ、にあなたの
index
変数を養うために試してみたいことがあり
誰かがあなたのコードを修正したり、TFでone-hotエンコーダを使用する方法を教えたりしますか? –
私は上記のコードに示すように、TFのワンホットエンコーダの使い方を知っていると思います。 TFの計算グラフを使って、文字列ラベルをワンホットベクトルに変換するにはどうすればいいでしょうか。ありがとうございました。 –
基本的には、col2が10個のラベル文字列の1つと等しく、ラベルインデックスを取得してからcol2の値を1ホットベクトルに変更するときです。すべてのエントリに対してこれを行い、その後、1つのホットラベルのバッチ全体を返します。 –