CNNに画像を読み取るためにパイプラインを使用しようとしています。私はstring_input_producer()
を使用してファイル名のキューを取得しましたが、何もせずにそこにハングアップしているようです。以下は私のコードです、私はそれを動作させる方法のアドバイスをしてください。TensorFlow画像読み取りキューが空です。
def read_image_file(filename_queue, labels):
reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)
image = tf.image.decode_png(value, channels=3)
image = tf.cast(image, tf.float32)
resized_image = tf.image.resize_images(image, [224, 112])
with tf.Session() as sess:
label = getLabel(labels, key.eval())
return resized_image, label
def input_pipeline(filename_queue, queue_names, batch_size, num_epochs, labels):
image, label = read_image_file(filename_queue, labels)
min_after_dequeue = 10 * batch_size
capacity = 20 * batch_size
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=1, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return image_batch, label_batch
train_queue = tf.train.string_input_producer(trainnames, shuffle=True, num_epochs=epochs)
train_batch, train_label = input_pipeline(train_queue, trainnames, batch_size, epochs, labels)
prediction = AlexNet(x)
#Training
with tf.name_scope("cost_function") as scope:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=train_label, logits=prediction(train_batch)))
tf.summary.scalar("cost_function", cost)
train_step = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(cost)
#Accuracy
with tf.name_scope("accuracy") as scope:
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar("accuracy", accuracy)
merged = tf.summary.merge_all()
#Session
with tf.Session() as sess:
print('started')
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord, start=True)
sess.run(threads)
try:
for step in range(steps_per_epch * epochs):
print('step: %d' %step)
sess.run(train_step)
except tf.errors.OutOfRangeError as ex:
pass
coord.request_stop()
coord.join(threads)
私のコードで 'getLabel'を定義しましたが、ここでは付けなかったので、基本的にファイル名(文字列)からラベルを抽出しますが、キーはテンソルです。そこで私は 'key.eval()'を実行してファイル名の文字列を取得しました。今はうまくいかないようですが、テンソルから文字列を取得する別の方法はありますか? – ALeex
おそらくすべての文字列演算を文字列テンソル演算で置き換える必要があるため、グラフの一部になり、実行時に実行されます。 – npf
https://www.tensorflow.org/api_guides/python/string_ops – npf