私はtf.train.string_input_producer
がどのように動作するかに疑問を抱いています。だから私は入力パラメータとしてfilename_listをstring_input_producer
に与えたとします。次に、ドキュメントhttps://www.tensorflow.org/programmers_guide/reading_dataに従って、FIFOQueue
が作成されます。ここでは、エポック番号を設定したり、ファイル名をシャッフルしたりできます。したがって、私の場合、私は4つのファイル名( "db1.tfrecords"、 "db2.tfrecords" ...)を持っています。そして、私はtf.train.batch
を使って画像のネットワークバッチを送りました。さらに、それぞれのファイル名/データベースには、1人の人物の画像セットが含まれています。 2番目のデータベースは2番目の人のためのデータベースです。これまでのところ私は、次のコードを持っている:オートエンコーダの出力で再構成画像を表示しようとするときテンソルフローのtf.train.string_input_producerでエポック番号を調べる
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
(common + "P21_db.tfrecords")]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'annotation_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image, [height, width, 3])
annotation = tf.cast(features['annotation_raw'], tf.string)
min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
shapes=[[], [112, 112, 3]],
batch_size=batch_size,
capacity=capacity,
num_threads=num_threads)
は最後に、私は、私はからの画像を見る第一のデータベースから画像を最初に開始しました2番目のデータベースなどです。
私の質問:同じエポック内にいるかどうかはどうすれば分かりますか?そして、私が元気な時代の中にいるなら、どうすれば私が持っているすべてのfile_namesのイメージのバッチをマージできますか?
最後に、私は次のようにSession
内のローカル変数を評価することにより、エポックの値を印刷しようとした:
epoch_var = tf.local_variables()[0]
そして:
with tf.Session() as sess:
print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y.
任意の助けをはるかに高く評価されています!
'tf.python_io.tf_record_iterator'を使ってレコードの数を数えることができ、バッチサイズを指定すると、現在のエポック番号を取得する必要があります。あなたの2番目の質問をしなかった。 –
@vijaym、これは私が求めていることではありません。私には 'tf.train.string_input_producer'があり、' tf.python_io.tf_record_iterator'はありません。 –