"新しい"(TensorFlow v1.4)データセットAPIを使用してTFRecordフォーマットイメージデータを読み取る際に問題があります。私は問題を読んでいるときに、私は何とか1つのバッチではなく、データセット全体を消費していると考えています。ここでは、バッチ/ファイルキューAPIを使用してこれを行う例を示します。https://github.com/gnperdue/TFExperiments/tree/master/conv(例では、分類子を実行していますが、TFRecordイメージを読み取るコードはDataReaders.py
クラスです)。新しいTensorFlowデータセットAPIを使用してTFRecordイメージデータを読み取る
問題の機能は、私は信じて、これらのです:
def parse_mnist_tfrec(tfrecord, features_shape):
tfrecord_features = tf.parse_single_example(
tfrecord,
features={
'features': tf.FixedLenFeature([], tf.string),
'targets': tf.FixedLenFeature([], tf.string)
}
)
features = tf.decode_raw(tfrecord_features['features'], tf.uint8)
features = tf.reshape(features, features_shape)
features = tf.cast(features, tf.float32)
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0)
targets = tf.cast(targets, tf.float32)
return features, targets
class MNISTDataReaderDset:
def __init__(self, data_reader_dict):
# doesn't matter here
def batch_generator(self, num_epochs=1):
def parse_fn(tfrecord):
return parse_mnist_tfrec(
tfrecord, self.name, self.features_shape
)
dataset = tf.data.TFRecordDataset(
self.filenames_list, compression_type=self.compression_type
)
dataset = dataset.map(parse_fn)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(self.batch_size)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
次に、使用中:
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
Input to reshape is a tensor with 50000 values, but the requested shape has 1
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
誰もがいずれかを持っています:
batch_features, batch_labels = \
data_reader.batch_generator(num_epochs=1)
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
# look at 3 batches only
for _ in range(3):
labels, feats = sess.run([
batch_labels, batch_features
])
これは、のようなエラーが発生しアイデア?
私は読者の例では、完全なコードを要旨とここ(TFRecordの形で私たちの古い、良い友人MNIST、)TFRecordファイルへのリンクを持っている:
https://gist.github.com/gnperdue/56092626d611ae23370a21fdeeb2abe8
に感謝を!
編集 - 私もflat_map
を試してみました、例えば:
def batch_generator(self, num_epochs=1):
"""
TODO - we can use placeholders for the list of file names and
init with a feed_dict when we call `sess.run` - give this a
try with one list for training and one for validation
"""
def parse_fn(tfrecord):
return parse_mnist_tfrec(
tfrecord, self.name, self.features_shape
)
dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list)
dataset = dataset.flat_map(
lambda filename: (
tf.data.TFRecordDataset(
filename, compression_type=self.compression_type
).map(parse_fn).batch(self.batch_size)
)
)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
私はまた、1つのファイルだけではなく(上記の本に近づいての私の最初の方法で)リストを使用してみました。何があっても、TFはいつもファイル全体をTFRecordDataset
に食べたいと思うので、単一のレコードでは動作しません。