2016-09-04 8 views
2

私はTensorFlowクリーンな方法(tf.train.shuffle_batch)を使用して入力データを処理しようとしています。このコードのほとんどは、decode_jpeg関数のようなわずかな変更でチュートリアルから集めたものです。私はtf.train.shuffle_batchが私のために働いていない

batch_x, batch_y = inputs(True, 100,100) 

を実行しようとすると、私は次のエラーを取得する

size = 32,32 
classes = 43 
train_size = 12760 
batch_size = 100 
max_steps = 10000 

def read_and_decode(filename_queue): 
    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'image/encoded': tf.FixedLenFeature([], tf.string), 
      'image/class/label': tf.FixedLenFeature([], tf.int64), 
      'image/height': tf.FixedLenFeature([], tf.int64), 
      'image/width': tf.FixedLenFeature([], tf.int64), 
     }) 
    label = tf.cast(features['image/class/label'], tf.int32) 
    reshaped_image = tf.image.decode_jpeg(features['image/encoded']) 
    reshaped_image = tf.image.resize_images(reshaped_image, size[0], size[1], method = 0) 
    reshaped_image = tf.image.per_image_whitening(reshaped_image) 
    return reshaped_image, label 

def inputs(train, batch_size, num_epochs): 
    subset = "train" 
    tf_record_pattern = os.path.join(FLAGS.train_dir + '/GTSRB', '%s-*' % subset) 
    data_files = tf.gfile.Glob(tf_record_pattern) 
    filename_queue = tf.train.string_input_producer(
     data_files, num_epochs=num_epochs) 

    # Even when reading in multiple threads, share the filename 
    # queue. 
    image, label = read_and_decode(filename_queue) 

    # Shuffle the examples and collect them into batch_size batches. 
    # (Internally uses a RandomShuffleQueue.) 
    # We run this in two threads to avoid being a bottleneck. 
    images, sparse_labels = tf.train.shuffle_batch(
     [image, label], batch_size=batch_size, num_threads=2, 
     capacity=1000 + 3 * batch_size, 
     # Ensures a minimum amount of shuffling of examples. 
     min_after_dequeue=1000) 
    return images, sparse_labels 

--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
<ipython-input-6-543290a0c903> in <module>() 
----> 1 batch_x, batch_y = inputs(True, 100,100) 

<ipython-input-5-a8c07c7fc263> in inputs(train, batch_size, num_epochs) 
    73   capacity=1000 + 3 * batch_size, 
    74   # Ensures a minimum amount of shuffling of examples. 
---> 75   min_after_dequeue=1000) 
    76  #return image, label 
    77  return images, sparse_labels 

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/training/input.pyc in shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads, seed, enqueue_many, shapes, allow_smaller_final_batch, shared_name, name) 
    800  queue = data_flow_ops.RandomShuffleQueue(
    801   capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, 
--> 802   dtypes=types, shapes=shapes, shared_name=shared_name) 
    803  _enqueue(queue, tensor_list, num_threads, enqueue_many) 
    804  full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), 

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.pyc in __init__(self, capacity, min_after_dequeue, dtypes, shapes, names, seed, shared_name, name) 
    580  """ 
    581  dtypes = _as_type_list(dtypes) 
--> 582  shapes = _as_shape_list(shapes, dtypes) 
    583  names = _as_name_list(names, dtypes) 
    584  # If shared_name is provided and an op seed was not provided, we must ensure 

/Users/Kevin/tensorflow/lib/python2.7/site-packages/tensorflow/python/ops/data_flow_ops.pyc in _as_shape_list(shapes, dtypes, unknown_dim_allowed, unknown_rank_allowed) 
    70 if not unknown_dim_allowed: 
    71  if any([not shape.is_fully_defined() for shape in shapes]): 
---> 72  raise ValueError("All shapes must be fully defined: %s" % shapes) 
    73 if not unknown_rank_allowed: 
    74  if any([shape.dims is None for shape in shapes]): 

ValueError: All shapes must be fully defined: [TensorShape([Dimension(32), Dimension(32), Dimension(None)]), TensorShape([])] 

私はこのエラーを引き起こしているかわからないんだけど、私はそれが何かを持っている想像彼らは3つのチャネル(RGB)を持つ必要があるときに彼らは次元を持っていないことを示すので、私は私のイメージを処理している方法で。

答えて

3

batching methods in TensorFlowtf.train.batch()tf.train.batch_join()tf.train.shuffle_batch()、及びtf.train.shuffle_batch_join())は、それらが密なテンソルに充填することができるように、バッチのすべての要素が、全く同じ形状*を有することを必要とします。あなたのコードでは、tf.train.shuffle_batch()に渡すテンソルimageの3番目の次元は不明なサイズであるようです。これは、各画像のチャンネル数に対応します。モノクロ画像の場合は1、カラー画像の場合は3、アルファチャンネルのカラー画像の場合は4です。 channels=NNが1、3、または4のいずれかである場合)を明示的に渡すと、TensorFlowは画像テンソルの形状に関する十分な情報を取得します。


  * 1つの例外を除いて:あなたはtf.train.batch()またはtf.train.batch_join()dynamic_pad=Trueを渡すときの要素は異なる形状を持つことができますが、彼らは同じランクを持っている必要があります。一般に、これは画像データではなく連続したデータに対してのみ使用されます(画像のエッジで望ましくない動作をする)。

+0

私は自分のコードにサイズを追加しました。それは32,32という静的な値です。 – Kevin

+0

ああ、それは未知のチャンネルの数のようです。更新された回答をご覧ください。 – mrry

+0

ありがとう、decode_jpegのチャンネルを入れて私の問題を解決! – Kevin

関連する問題