2017-11-27 13 views
3

"新しい"(TensorFlow v1.4)データセットAPIを使用してTFRecordフォーマットイメージデータを読み取る際に問題があります。私は問題を読んでいるときに、私は何とか1つのバッチではなく、データセット全体を消費していると考えています。ここでは、バッチ/ファイルキューAPIを使用してこれを行う例を示します。https://github.com/gnperdue/TFExperiments/tree/master/conv(例では、分類子を実行していますが、TFRecordイメージを読み取るコードはDataReaders.pyクラスです)。新しいTensorFlowデータセットAPIを使用してTFRecordイメージデータを読み取る

問題の機能は、私は信じて、これらのです:

def parse_mnist_tfrec(tfrecord, features_shape): 
    tfrecord_features = tf.parse_single_example(
     tfrecord, 
     features={ 
      'features': tf.FixedLenFeature([], tf.string), 
      'targets': tf.FixedLenFeature([], tf.string) 
     } 
    ) 
    features = tf.decode_raw(tfrecord_features['features'], tf.uint8) 
    features = tf.reshape(features, features_shape) 
    features = tf.cast(features, tf.float32) 
    targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8) 
    targets = tf.one_hot(indices=targets, depth=10, on_value=1, off_value=0) 
    targets = tf.cast(targets, tf.float32) 
    return features, targets 

class MNISTDataReaderDset: 
    def __init__(self, data_reader_dict): 
     # doesn't matter here 

    def batch_generator(self, num_epochs=1): 
     def parse_fn(tfrecord): 
      return parse_mnist_tfrec(
       tfrecord, self.name, self.features_shape 
      ) 
     dataset = tf.data.TFRecordDataset(
      self.filenames_list, compression_type=self.compression_type 
     ) 
     dataset = dataset.map(parse_fn) 
     dataset = dataset.repeat(num_epochs) 
     dataset = dataset.batch(self.batch_size) 
     iterator = dataset.make_one_shot_iterator() 
     batch_features, batch_labels = iterator.get_next() 
     return batch_features, batch_labels 

次に、使用中:

[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]] 
Input to reshape is a tensor with 50000 values, but the requested shape has 1 
[[Node: Reshape_1 = Reshape[T=DT_UINT8, Tshape=DT_INT32](DecodeRaw_1, Reshape_1/shape)]] 
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,28,28,1], [?,10]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]] 

誰もがいずれかを持っています:

 batch_features, batch_labels = \ 
      data_reader.batch_generator(num_epochs=1) 

     sess.run(tf.local_variables_initializer()) 
     coord = tf.train.Coordinator() 
     threads = tf.train.start_queue_runners(coord=coord) 
     try: 
      # look at 3 batches only 
      for _ in range(3): 
       labels, feats = sess.run([ 
        batch_labels, batch_features 
       ]) 

これは、のようなエラーが発生しアイデア?

私は読者の例では、完全なコードを要旨とここ(TFRecordの形で私たちの古い、良い友人MNIST、)TFRecordファイルへのリンクを持っている:

https://gist.github.com/gnperdue/56092626d611ae23370a21fdeeb2abe8

に感謝を!

編集 - 私もflat_mapを試してみました、例えば:

def batch_generator(self, num_epochs=1): 
    """ 
    TODO - we can use placeholders for the list of file names and 
    init with a feed_dict when we call `sess.run` - give this a 
    try with one list for training and one for validation 
    """ 
    def parse_fn(tfrecord): 
     return parse_mnist_tfrec(
      tfrecord, self.name, self.features_shape 
     ) 
    dataset = tf.data.Dataset.from_tensor_slices(self.filenames_list) 
    dataset = dataset.flat_map(
     lambda filename: (
      tf.data.TFRecordDataset(
       filename, compression_type=self.compression_type 
      ).map(parse_fn).batch(self.batch_size) 
     ) 
    ) 
    dataset = dataset.repeat(num_epochs) 
    iterator = dataset.make_one_shot_iterator() 
    batch_features, batch_labels = iterator.get_next() 
    return batch_features, batch_labels 

私はまた、1つのファイルだけではなく(上記の本に近づいての私の最初の方法で)リストを使用してみました。何があっても、TFはいつもファイル全体をTFRecordDatasetに食べたいと思うので、単一のレコードでは動作しません。

答えて

0

これはわかりました。上のコードは問題ありません。問題は、TFRecordsを作成するためのスクリプトでした。基本的に、私はこの

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file): 
    writer = tf.python_io.TFRecordWriter(tfrecord_file) 
    tfeat, ttarg = get_binary_data(reader, start_idx, stop_idx) 
    example = tf.train.Example(
     features=tf.train.Features(
      feature={ 
       'features': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tfeat]) 
       ), 
       'targets': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[ttarg]) 
       ) 
      } 
     ) 
    ) 
    writer.write(example.SerializeToString()) 
    writer.close() 

のようなブロックを持っていたし、私はこの代わりのようなブロックに必要な:

def write_tfrecord(reader, start_idx, stop_idx, tfrecord_file): 
    writer = tf.python_io.TFRecordWriter(tfrecord_file) 
    for idx in range(start_idx, stop_idx): 
     tfeat, ttarg = get_binary_data(reader, idx) 
     example = tf.train.Example(
      features=tf.train.Features(
       feature={ 
        'features': tf.train.Feature(
         bytes_list=tf.train.BytesList(value=[tfeat]) 
        ), 
        'targets': tf.train.Feature(
         bytes_list=tf.train.BytesList(value=[ttarg]) 
        ) 
       } 
      ) 
     ) 
     writer.write(example.SerializeToString()) 
    writer.close() 

言うことです - 私は基本的に1つの巨大なTFRecordときのようにデータの私の全体のブロックを書いていたが私は、データに1つの例を1つ作成する必要がありました。

古いファイルとバッチキューAPIのいずれかで動作させると、tf.train.batchのような機能は、巨大なブロックを彫るか、あなたが何を与えるかに応じて、バッチにレコードを記録します。 TFRecordsファイルを作成したコードを修正したとき、古いファイルとバッチキューコードで何も変更する必要はなく、それでもTFRecordsファイルを使い切ってしまった。ただし、Dataset APIはこの違いに敏感です。そのため、上記の私のコードでは、ファイル全体が本当に1つの大きなTFRecordだったので、常にファイル全体を消費するように見えました。

関連する問題