を使用して、TensorFlow 1.0を使用してTFRecord
ファイルのデータのバッチを消費しようとしています。関連する関数は、次のとおりTensorFlowで `TFRecord`ファイルを読むときに` tf.train.shuffle_batch`がクラッシュする
def tfrecord_to_graph_ops(filenames_list):
file_queue = tf.train.string_input_producer(filenames_list)
reader = tf.TFRecordReader()
_, tfrecord = reader.read(file_queue)
tfrecord_features = tf.parse_single_example(
tfrecord,
features={'targets': tf.FixedLenFeature([], tf.string)}
)
## if no reshaping: `ValueError: All shapes must be fully defined` in
## `tf.train.shuffle_batch`
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
## if using `strided_slice`, always get the first record
# targets = tf.cast(
# tf.strided_slice(targets, [0], [1]),
# tf.int32
#)
## error on shapes being fully defined
# targets = tf.reshape(targets, [])
## get us: Invalid argument: Shape mismatch in tuple component 0.
## Expected [1], got [1000]
targets.set_shape([1])
return targets
def batch_generator(filenames_list, batch_size=BATCH_SIZE):
targets = tfrecord_to_graph_ops(filenames_list)
targets_batch = tf.train.shuffle_batch(
[targets],
batch_size=batch_size,
capacity=(20 * batch_size),
min_after_dequeue=(2 * batch_size)
)
targets_batch = tf.one_hot(
indices=targets_batch, depth=10, on_value=1, off_value=0
)
return targets_batch
def examine_batches(targets_batch):
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for _ in range(10):
targets = sess.run([targets_batch])
print(targets)
coord.request_stop()
coord.join(threads)
コードがbatch_generator()
の出力を渡された、examine_batches()
通って入ります。 batch_generator()
はtfrecord_to_graph_ops()
と呼ばれ、問題はその機能にあります。
I 1,000バイト(番号0-9)でファイルに
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
を呼び出しています。私がこのセッションでeval()
と呼ぶと、それは私に1,000要素すべてを表示します。しかし、バッチジェネレータに入れようとするとクラッシュします。
targets
を再構成しないと、tf.train.shuffle_batch
を呼び出すとValueError: All shapes must be fully defined
のようなエラーが発生します。 GoogleのCIFAR-10 example codeを彷彿とさせるtargets.set_shape([1])
と電話をかけると、のようなエラーがtf.train.shuffle_batch
に届きます。私はまた、tf.strided_slice
を使って生データのチャンクをカットしようとしました - これはクラッシュしませんが、最初のイベントを何度も何度も何度も何度も繰り返すことになります。
これを行う正しい方法は何ですか?ファイルTFRecord
からバッチを引き出すには?
生のバイトデータを細かく切り詰めた関数を手作業で書くことができました。特に、グラフにデータを取り込む方法がfeed_dict
の場合は特に簡単ですが、 TensorFlowのTFRecord
ファイルを使用し、組み込みのバッチ関数を使用する方法について説明します。
ありがとうございます!
デコードされた文字列には常に1000要素がありますか?そうであれば、形状を1ではなく1000に設定することができます。それ以外の場合は、固定された形状を得るための一般的な解決策です。あるいは、これらの文字をキューに個別に追加したい場合は、 'enuffue_many = True'を' shuffle_batch'に使うことができます。 –
いいえ、数字は変わります - これは単なるおもちゃの問題です。 'enqueue_many'は面白いアイディアです。 –
'enqueue_many = True'がトリックであることが判明しました! –