2017-03-07 7 views
1

画像分類モデルにtf.contrib.learn.Estimatorで必要とされるinput_fnを構築する方法の例はありますか?私の画像は複数のTFRecordsファイルに保存されています。TFRecordsファイルに格納された画像を使用してEstimatorのinput_fnを構築する方法

tf.contrib.learn.read_batch_record_featuresを使用して、私はコード化されたイメージ文字列のバッチを生成することができます。しかし、これらの文字列を画像に変換する簡単な方法はありません。

+0

参考cifar10例えば、具体的には入力の実装:https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py –

答えて

0

あなたはtrain.tfrecordstest.tfrecordsに保存されているmnistfashion-mnistデータセットについては、以下のようなものを使用することができますhere参照。

tfrecordsへの変換はコードhereで行われ、元の画像とラベルを取得するにはパーサーが必要です。

def parser(serialized_example): 
    """Parses a single tf.Example into image and label tensors.""" 
    features = tf.parse_single_example(
     serialized_example, 
     features={ 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'label': tf.FixedLenFeature([], tf.int64), 
     }) 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    image.set_shape([28 * 28]) 

    # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5] 
    image = tf.cast(image, tf.float32)/255 - 0.5 
    label = tf.cast(features['label'], tf.int32) 
    return image, label 

残りの部分は簡単ですパーサを持った後、あなたはちょうどあなたが出力として画像とラベルを取得しますのでTFRecordDataset(train_filenames)を呼び出してから、各要素にパーサ機能をマップする必要があります。

# Keep list of filenames, so you can input directory of tfrecords easily 
training_filenames = ["data/train.tfrecords"] 
test_filenames = ["data/test.tfrecords"] 

# Define the input function for training 
def train_input_fn(): 
    # Import MNIST data 
    dataset = tf.contrib.data.TFRecordDataset(train_filenames) 

    # Map the parser over dataset, and batch results by up to batch_size 
    dataset = dataset.map(parser, num_threads=1, output_buffer_size=batch_size) 
    dataset = dataset.batch(batch_size) 
    dataset = dataset.repeat() 
    iterator = dataset.make_one_shot_iterator() 

    features, labels = iterator.get_next() 

    return features, labels 
関連する問題