2017-08-15 3 views
1

私はtf.train.string_input_producerがどのように動作するかに疑問を抱いています。だから私は入力パラメータとしてfilename_listをstring_input_producerに与えたとします。次に、ドキュメントhttps://www.tensorflow.org/programmers_guide/reading_dataに従って、FIFOQueueが作成されます。ここでは、エポック番号を設定したり、ファイル名をシャッフルしたりできます。したがって、私の場合、私は4つのファイル名( "db1.tfrecords"、 "db2.tfrecords" ...)を持っています。そして、私はtf.train.batchを使って画像のネットワークバッチを送りました。さらに、それぞれのファイル名/データベースには、1人の人物の画像セットが含まれています。 2番目のデータベースは2番目の人のためのデータベースです。これまでのところ私は、次のコードを持っている:オートエンコーダの出力で再構成画像を表示しようとするときテンソルフローのtf.train.string_input_producerでエポック番号を調べる

tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"), 
          (common + "P21_db.tfrecords")] 

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue') 
reader = tf.TFRecordReader() 

key, serialized_example = reader.read(filename_queue) 
features = tf.parse_single_example(
    serialized_example, 
    # Defaults are not specified since both keys are required. 
    features={ 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'annotation_raw': tf.FixedLenFeature([], tf.string) 
    }) 

image = tf.decode_raw(features['image_raw'], tf.uint8) 
height = tf.cast(features['height'], tf.int32) 
width = tf.cast(features['width'], tf.int32) 

image = tf.reshape(image, [height, width, 3]) 

annotation = tf.cast(features['annotation_raw'], tf.string) 

min_after_dequeue = 100 
num_threads = 4 
capacity = min_after_dequeue + num_threads * batch_size 
label_batch, images_batch = tf.train.batch([annotation, image], 
                 shapes=[[], [112, 112, 3]], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 num_threads=num_threads) 

は最後に、私は、私はからの画像を見る第一のデータベースから画像を最初に開始しました2番目のデータベースなどです。

私の質問:同じエポック内にいるかどうかはどうすれば分かりますか?そして、私が元気な時代の中にいるなら、どうすれば私が持っているすべてのfile_namesのイメージのバッチをマージできますか?

最後に、私は次のようにSession内のローカル変数を評価することにより、エポックの値を印刷しようとした:

epoch_var = tf.local_variables()[0] 

そして:

with tf.Session() as sess: 
    print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y. 

任意の助けをはるかに高く評価されています!

+0

'tf.python_io.tf_record_iterator'を使ってレコードの数を数えることができ、バッチサイズを指定すると、現在のエポック番号を取得する必要があります。あなたの2番目の質問をしなかった。 –

+0

@vijaym、これは私が求めていることではありません。私には 'tf.train.string_input_producer'があり、' tf.python_io.tf_record_iterator'はありません。 –

答えて

0

だから私が考えたのは、tf.train.shuffle_batch_joinを使用すると、別のデータセットから画像をシャッフルするときに問題が解決するということです。つまり、すべてのバッチにすべてのデータセット/ファイル名の画像が含まれています。次に例を示します。

def read_my_file_format(filename_queue): 
    reader = tf.TFRecordReader() 
    key, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'annotation_raw': tf.FixedLenFeature([], tf.string) 
     }) 

    # This is how we create one example, that is, extract one example from the database. 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    # The height and the weights are used to 
    height = tf.cast(features['height'], tf.int32) 
    width = tf.cast(features['width'], tf.int32) 

    # The image is reshaped since when stored as a binary format, it is flattened. Therefore, we need the 
    # height and the weight to restore the original image back. 
    image = tf.reshape(image, [height, width, 3]) 

    annotation = tf.cast(features['annotation_raw'], tf.string) 
    return annotation, image 

def input_pipeline(filenames, batch_size, num_threads, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epoch, shuffle=False, 
                name='queue') 
    # Therefore, Note that here we have created num_threads readers to read from the filename_queue. 
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)] 
    min_after_dequeue = 100 
    capacity = min_after_dequeue + num_threads * batch_size 
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list, 
                  shapes=[[], [112, 112, 3]], 
                  batch_size=batch_size, 
                  capacity=capacity, 
                  min_after_dequeue=min_after_dequeue) 
    return label_batch, images_batch, example_list 

label_batch, images_batch, input_ann_img = \ 
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch) 

そして今、これはFIFOQueueから読み取るために、読者の数を作成するために起こっている、それぞれのリーダーの後に別のデコーダを持つことになります。最後に、画像をデコードした後、tf.train.shuffle_batch_joinを呼び出して作成した別のQueueに画像のバッチを送ります。

関連する問題