2017-07-08 7 views
0

ファイルに格納されている1000のサイズの画像があります。データセットを2回(2エポック)反復したいので、tf.train.string_input_producer([dummy.tfrecord], num_epochs=2)と指定します。バッチサイズが100の場合、tf.train.shuffle_batch10010バッチが1000の画像を消耗するので、2 * 10 = 20回の実行を期待します。random_shuffle_queueが要素を使い果たしたときの理解

私は、this answerに続き、予想通りに20反復を生成しました。そこ0画像がキューに残っているので、

RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 100, current size 0) 

理にかなって:しかし、最後に、私はエラーを受け取りました。

キューを閉じて正常に終了するにはどうすればよいですか?つまり、エラーではありません。ここで

は完全なスクリプトです:誰でも再現したい場合

import numpy as np 
import tensorflow as tf 

NUM_IMGS = 1000 
tfrecord_file = 'dummy.tfrecord' 

def read_from_tfrecord(filenames): 
    tfrecord_file_queue = tf.train.string_input_producer(filenames, 
      num_epochs=2) 
    reader = tf.TFRecordReader() 
    _, tfrecord_serialized = reader.read(tfrecord_file_queue) 

    tfrecord_features = tf.parse_single_example(tfrecord_serialized, 
         features={ 
          'label': tf.FixedLenFeature([], tf.string), 
          'image': tf.FixedLenFeature([], tf.string), 
         }, name='features') 

    image = tf.decode_raw(tfrecord_features['image'], tf.uint8) 
    image = tf.reshape(image, shape=(32, 32, 3)) 

    label = tf.cast(tfrecord_features['label'], tf.string) 

    #provide batches 
    images, labels = tf.train.shuffle_batch([image, label], 
      batch_size=100, 
      num_threads=4, 
      capacity=50, 
      min_after_dequeue=1) 

    return images, labels 

imgs, lbls = read_from_tfrecord([tfrecord_file]) 
init_op = tf.group(tf.global_variables_initializer(), 
     tf.local_variables_initializer()) 

with tf.Session() as sess: 
    sess.run(init_op) 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    while not coord.should_stop(): 
     labels, images = sess.run([lbls, imgs]) 
     print(images.shape) #PRINTED 20 TIMES BUT FAILED AT THE 21ST 
    coord.request_stop() 
    coord.join(threads) 

はここで、dummy.tfrecordファイルを生成するスクリプトです:

def generate_image_binary(): 
    images = np.random.randint(0,255, size=(NUM_POINTS, 32, 32, 3), 
      dtype=np.uint8) 
    labels = np.random.randint(0,2, size=(NUM_POINTS, 1)) 
    return labels, images 

def write_to_tfrecord(labels, images, tfrecord_file): 
    writer = tf.python_io.TFRecordWriter(tfrecord_file) 

    for i in range(NUM_POINTS): 
     example = tf.train.Example(features=tf.train.Features(feature={ 
        'label': 
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[labels[i].tobytes()])), 
        'image': 
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[images[i].tobytes()])) 
        })) 
     writer.write(example.SerializeToString()) 
    writer.close() 

tfrecord_file = 'dummy.tfrecord' 
labels, images= generate_image_binary() 
write_to_tfrecord(labels, images, tfrecord_file) 

答えて

0

tf.errors.OutOfRangeErrorようCoordinatorキャッチして処理することができます例外は、どのキューが閉じられたことを報告するために使用されます。私たちはその例外を発生させる持つことはできません

with tf.Session() as sess: 
sess.run(init_op) 
coord = tf.train.Coordinator() 
threads = tf.train.start_queue_runners(coord=coord) 
try: 
    while not coord.should_stop(): 
     labels, images = sess.run([lbls, imgs]) 
     print(images.shape) #PRINTED 20 TIMES BUT FAILED AT THE 21ST 
except Exception, e: 
    # When done, ask the threads to stop. 
    coord.request_stop(e) 

finally: 
    coord.request_stop() 
    # Wait for threads to finish. 
coord.join(threads) 
+0

: あなたは上記の例外を処理するために、あなたのコードの変更を行うことができますか? – user3813674

+0

はコードを編集し、例外を 'Coordinator'に報告します –

関連する問題