2017-06-08 6 views
1

を使用して、TensorFlow 1.0を使用してTFRecordファイルのデータのバッチを消費しようとしています。関連する関数は、次のとおりTensorFlowで `TFRecord`ファイルを読むときに` tf.train.shuffle_batch`がクラッシュする

def tfrecord_to_graph_ops(filenames_list): 
    file_queue = tf.train.string_input_producer(filenames_list) 
    reader = tf.TFRecordReader() 
    _, tfrecord = reader.read(file_queue) 

    tfrecord_features = tf.parse_single_example(
     tfrecord, 
     features={'targets': tf.FixedLenFeature([], tf.string)} 
    ) 
    ## if no reshaping: `ValueError: All shapes must be fully defined` in 
    ## `tf.train.shuffle_batch` 
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8) 
    ## if using `strided_slice`, always get the first record 
    # targets = tf.cast(
    #  tf.strided_slice(targets, [0], [1]), 
    #  tf.int32 
    #) 
    ## error on shapes being fully defined 
    # targets = tf.reshape(targets, []) 
    ## get us: Invalid argument: Shape mismatch in tuple component 0. 
    ## Expected [1], got [1000] 
    targets.set_shape([1]) 
    return targets 


def batch_generator(filenames_list, batch_size=BATCH_SIZE): 
    targets = tfrecord_to_graph_ops(filenames_list) 
    targets_batch = tf.train.shuffle_batch(
     [targets], 
     batch_size=batch_size, 
     capacity=(20 * batch_size), 
     min_after_dequeue=(2 * batch_size) 
    ) 
    targets_batch = tf.one_hot(
     indices=targets_batch, depth=10, on_value=1, off_value=0 
    ) 
    return targets_batch 


def examine_batches(targets_batch): 
    with tf.Session() as sess: 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     for _ in range(10): 
      targets = sess.run([targets_batch]) 
      print(targets) 
     coord.request_stop() 
     coord.join(threads) 

コードがbatch_generator()の出力を渡された、examine_batches()通って入ります。 batch_generator()tfrecord_to_graph_ops()と呼ばれ、問題はその機能にあります。

I 1,000バイト(番号0-9)でファイルに

targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8) 

を呼び出しています。私がこのセッションでeval()と呼ぶと、それは私に1,000要素すべてを表示します。しかし、バッチジェネレータに入れようとするとクラッシュします。

targetsを再構成しないと、tf.train.shuffle_batchを呼び出すとValueError: All shapes must be fully definedのようなエラーが発生します。 GoogleのCIFAR-10 example codeを彷彿とさせるtargets.set_shape([1])と電話をかけると、のようなエラーがtf.train.shuffle_batchに届きます。私はまた、tf.strided_sliceを使って生データのチャンクをカットしようとしました - これはクラッシュしませんが、最初のイベントを何度も何度も何度も何度も繰り返すことになります。

これを行う正しい方法は何ですか?ファイルTFRecordからバッチを引き出すには?

生のバイトデータを細かく切り詰めた関数を手作業で書くことができました。特に、グラフにデータを取り込む方法がfeed_dictの場合は特に簡単ですが、 TensorFlowのTFRecordファイルを使用し、組み込みのバッチ関数を使用する方法について説明します。

ありがとうございます!

+1

デコードされた文字列には常に1000要素がありますか?そうであれば、形状を1ではなく1000に設定することができます。それ以外の場合は、固定された形状を得るための一般的な解決策です。あるいは、これらの文字をキューに個別に追加したい場合は、 'enuffue_many = True'を' shuffle_batch'に使うことができます。 –

+0

いいえ、数字は変わります - これは単なるおもちゃの問題です。 'enqueue_many'は面白いアイディアです。 –

+0

'enqueue_many = True'がトリックであることが判明しました! –

答えて

1

Allen Lavoieは、コ​​メントに正しい解決方法を指摘しました。欠けている重要な部分は、tf.train.shuffle_batch()の引数としてenqueue_many=Trueでした。これらの関数を書く正しい方法は次のとおりです。

def tfrecord_to_graph_ops(filenames_list): 
    file_queue = tf.train.string_input_producer(filenames_list) 
    reader = tf.TFRecordReader() 
    _, tfrecord = reader.read(file_queue) 

    tfrecord_features = tf.parse_single_example(
     tfrecord, 
     features={'targets': tf.FixedLenFeature([], tf.string)} 
    ) 
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8) 
    targets = tf.reshape(targets, [-1]) 
    return targets 

def batch_generator(filenames_list, batch_size=BATCH_SIZE): 
    targets = tfrecord_to_graph_ops(filenames_list) 
    targets_batch = tf.train.shuffle_batch(
     [targets], 
     batch_size=batch_size, 
     capacity=(20 * batch_size), 
     min_after_dequeue=(2 * batch_size), 
     enqueue_many=True 
    ) 
    return targets_batch 
関連する問題