0
MNISTチュートリアルを変更して、TFRコードを奇妙なフォーマットではなくウェブからダウンロードする方法を教えてください。Tensorflow MNIST TFRecord
私は、入力モデルから200x200のRGBイメージを含むTFRecordsを作成し、これを1080Tiで訓練しようとしましたが、TFRecordsをロードして畳み込みニューラルネットワーク。
MNISTチュートリアルを変更して、TFRコードを奇妙なフォーマットではなくウェブからダウンロードする方法を教えてください。Tensorflow MNIST TFRecord
私は、入力モデルから200x200のRGBイメージを含むTFRecordsを作成し、これを1080Tiで訓練しようとしましたが、TFRecordsをロードして畳み込みニューラルネットワーク。
私はあなたがやっているのと同じことをしました。画像データを構築するために同じスクリプトを使用しました。データを読んで訓練する私のコードは
import tensorflow as tf
height = 28
width = 28
tfrecords_train_filename = 'train-00000-of-00001'
tfrecords_test_filename = 'test-00000-of-00001'
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'image/height': tf.FixedLenFeature([], tf.int64),
'image/width': tf.FixedLenFeature([], tf.int64),
'image/colorspace': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/channels': tf.FixedLenFeature([], tf.int64),
'image/class/label': tf.FixedLenFeature([], tf.int64),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/format': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='')
})
image_buffer = features['image/encoded']
image_label = tf.cast(features['image/class/label'], tf.int32)
# Decode the jpeg
with tf.name_scope('decode_jpeg', [image_buffer], None):
# decode
image = tf.image.decode_jpeg(image_buffer, channels=3)
# and convert to single precision data type
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.rgb_to_grayscale(image)
image_shape = tf.stack([height, width, 1])
image = tf.reshape(image, image_shape)
return image, image_label
def inputs(filename, batch_size, num_epochs):
if not num_epochs: num_epochs = None
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
image, label = read_and_decode(filename_queue)
# Shuffle the examples and collect them into batch_size batches.
images, sparse_labels = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size,
min_after_dequeue=1000)
return images, sparse_labels
image, label = inputs(filename=tfrecords_train_filename, batch_size=200, num_epochs=None)
image = tf.reshape(image, [-1, 784])
label = tf.one_hot(label - 1, 10)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
for i in range(1000):
img, lbl = sess.run([image, label])
sess.run(train_step, feed_dict={x: img, y_: lbl})
img, lbl = sess.run([image, label])
print(sess.run(accuracy, feed_dict={x: img, y_: lbl}))
coord.request_stop()
coord.join(threads)
mnistを分類するための超簡単なモデルです。しかし、私はそれがTFRecordファイルを使って訓練する方法のための拡張可能な答えだと思います。これは、評価データを考慮に入れていない。なぜなら、これにはさらに調整が必要だからである。
[このガイド](https://www.tensorflow.org/programmers_guide/datasets)には、TFRecordファイルおよびgtテンソルからデータをロードする方法の例があります。それでは、ネットワークが現時点でどのような入力でもなく、ネットワークに入力としてデータを渡すだけです。 – GPhilo
@GPhilo私は自分のデータセットを「images:Images。4Dテンソルのサイズで利用できます。」[batch_size、FLAGS.image_size、 image_size、3]。 ラベル:[FLAGS.batch_size]の1次元整数テンソル。 "しかし、読み込んだものを取る関数を持つtf.estimator.inputsは表示されません。 – Eejin
tf.estimator.inputsには、テンソル形式ではないデータをネットワーク上で変換できる便利な機能があります。 'input_fn'を書き直す必要があります。私はこの高レベルのAPIに精通していませんが、[Estimatorのドキュメント](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)から、あなたは 'input_fn'を定義する必要があると思います{{images}:your_image_tensor、 'labels':your_label_tensor} 'を返します。 – GPhilo