2017-08-01 13 views
0

私の質問は、複数の(またはシャードされた)tfrecordからバッチインプットを取得する方法です。私は例https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410を読んだ。基本的なパイプラインは、例えば、(1)これらのファイル名から一連のtfrecords(例えば、train-000-of-005train-001-of-005、...)、(2)を生成し、リストを生成してtf.train.string_input_producer (3)同時にtf.RandomShuffleQueueを生成し、(4)tf.train.batch_joinを使用してバッチインプットを生成する。tfrecordsからバッチインプットを処理するもっと簡単な方法はありますか?

私はこれが複雑だと思うし、私はこの手順のロジックをよく分かりません。私の場合、.npyファイルのリストがあり、sharded tfrecords(単一の大きなファイルではなく、複数の別々のtfrecords)を生成したいと考えています。これらの.npyファイルのそれぞれには、異なる数の正と負のサンプル(2つのクラス)が含まれています。基本的な方法は、1つの大きなtfrecordファイルを生成することです。ファイルが大きすぎます(~20Gb)。だから私はシャードされたtfrecordsに頼っている。これを行う簡単な方法はありますか?ありがとう。

答えて

11

Dataset APIを使用するとすべてのプロセスが単純化されます。次の2つの部分があります:(1): Convert numpy array to tfrecords(2,3,4): read the tfrecords to generate batches numpyの配列からtfrecordsの

1. 作成:

def npy_to_tfrecords(...): 
     # write records to a tfrecords file 
     writer = tf.python_io.TFRecordWriter(output_file) 

     # Loop through all the features you want to write 
     for ... : 
      let say X is of np.array([[...][...]]) 
      let say y is of np.array[[0/1]] 

     # Feature contains a map of string to feature proto objects 
     feature = {} 
     feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten())) 
     feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=y)) 

     # Construct the Example proto object 
     example = tf.train.Example(features=tf.train.Features(feature=feature)) 

     # Serialize the example to a string 
     serialized = example.SerializeToString() 

     # write the serialized objec to the disk 
     writer.write(serialized) 
     writer.close() 

2. は、データセットのAPIを使用してtfrecordsを読む(tensorflow> = 1.2):

# Creates a dataset that reads all of the examples from filenames. 
    filenames = ["file1.tfrecord", "file2.tfrecord", ..."fileN.tfrecord"] 
    dataset = tf.contrib.data.TFRecordDataset(filenames) 

    # example proto decode 
    def _parse_function(example_proto): 
     keys_to_features = {'X':tf.FixedLenFeature((shape_of_npy_array), tf.float32), 
          'y': tf.FixedLenFeature((), tf.int64, default_value=0)} 
     parsed_features = tf.parse_single_example(example_proto, keys_to_features) 
    return parsed_features['X'], parsed_features['y'] 

    # Parse the record into tensors. 
    dataset = dataset.map(_parse_function) 

    # Shuffle the dataset 
    dataset = dataset.shuffle(buffer_size=10000) 

    # Repeat the input indefinitly 
    dataset = dataset.repeat() 

    # Generate batches 
    dataset = dataset.batch(batch_size) 

    # Create a one-shot iterator 
    iterator = dataset.make_one_shot_iterator() 

    # Get batch X and y 
    X, y = iterator.get_next() 
+0

ああ、Iあなたの詳細な答えを非常に感謝します!あなたは私の命を救う! – mining

+0

こんにちは、このAPIは 'tf.train.shuffle_batch' APIのような' num_threads'や 'capacity'をサポートしていますか?私の場合、ネットワークが小さい場合、GPUでの実行はデータのロードよりも速く、アイドル状態のGPU時間につながります。だから、私はデータをフェッチするためのキューには常にいっぱいです。ありがとう。 – mining

+2

チェック:https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map –

関連する問題