2017-12-09 8 views
0

in the tutorial from the TensorFlow guide siteの手順に従ってAlexNet CNNモデルをトレーニングしようとしていますが、チュートリアルでは以下のコードを使用してトレーニングデータにロードしますTensorFlow - TFレコードが大きすぎてnp配列に一度にロードできない

mnist = tf.contrib.learn.datasets.load_dataset("mnist") 
train_data = mnist.train.images # Returns np.array 
train_labels = np.asarray(mnist.train.labels, dtype=np.int32) 
eval_data = mnist.test.images # Returns np.array 
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32) 

私にとっては、データセットの例をTFRecordファイルに書き込んだ後、トレーニング中にこれらのレコードを読み込んで、それをalexnetネットワークにフィードしてみてください。以下のコードを参照してください:

#FUNCTION TO GET ALL DATASET DATA 
def _read_multiple_images(filenames, perform_shuffle=False, repeat_count=1, 
batch_size=1, available_record=39209, num_of_epochs=1): 
    def _read_one_image(serialized): 
     #Specify the fatures you want to extract 
     features = {'image/shape': tf.FixedLenFeature([], tf.string), 
      'image/class/label': tf.FixedLenFeature([], tf.int64), 
      'image/class/text': tf.FixedLenFeature([], tf.string), 
      'image/filename': tf.FixedLenFeature([], tf.string), 
      'image/encoded': tf.FixedLenFeature([], tf.string)} 
     parsed_example = tf.parse_single_example(serialized, 
     features=features) 

     #Finese extracted data 
     image_raw = tf.decode_raw(parsed_example['image/encoded'], tf.uint8) 
     shape = tf.decode_raw(parsed_example['image/shape'], tf.int32) 
     label = tf.cast(parsed_example['image/class/label'], dtype=tf.int32) 
     reshaped_img = tf.reshape(image_raw, shape) 
     casted_img = tf.cast(reshaped_img, tf.float32) 
     label_tensor= [label] 
     image_tensor = [casted_img] 
     return label_tensor, image_tensor 

complete_labels = np.array([]) 
complete_images = np.array([]) 

dataset = tf.data.TFRecordDataset(filenames=filenames) 
dataset = dataset.map(_read_one_image) 
dataset = dataset.repeat(repeat_count)  #Repeats dataset this # times 
dataset = dataset.batch(batch_size)   #Batch size to use 
iterator = dataset.make_initializable_iterator() 
labels_tensor, images_tensor = iterator.get_next() #Get batch data 
no_of_rounds = int(math.ceil(available_record/batch_size)); 

#Create tf session, get nest set of batches, and evelauate them in batches 
sess = tf.Session() 
count=1 
for _ in range(num_of_epochs): 
    sess.run(iterator.initializer) 

    while True: 
    try: 
     evaluated_label, evaluated_image = sess.run([labels_tensor, 
     images_tensor]) 

     #convert evaluated tensors to np array 
     label_np_array = np.asarray(evaluated_label, dtype=np.uint8) 
     image_np_array = np.asarray(evaluated_image, dtype=np.uint8) 

     #squeeze np array to make dimesnsions appropriate 
     squeezed_label_np_array = label_np_array.squeeze() 
     squeezed_image_np_array = image_np_array.squeeze() 

     #add current batch to total 
     complete_labels = np.append(complete_labels, squeezed_label_np_array) 
     complete_images = np.append(complete_images, squeezed_image_np_array) 
     except tf.errors.OutOfRangeError: 
     print("End of Dataset Reached") 
     break 
    count=count+1 

sess.close() 
return complete_labels, complete_images 

私の主な問題は、私は私のTF推定に送り込むことができるように、NP配列としての私のデータセット(227x227x3)内のすべての39209枚の画像をリカするトリングながら、ということです。コンピュータのメモリが不足しています。

train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": 
complete_images},y=complete_labels,batch_size=100,num_epochs=1, 
shuffle=True) 
dataset_classifier.train(input_fn=train_input_fn,num_epochs=1,hooks= 
[logging_hook]) 

私はバッチで私のTFレコードのうち、私のイメージとラベルを取得し、むしろとしてNP配列にすべてをロードする必要がすることよりも、バッチで私TF.Estimatorにそれを養うことができます方法はありますあなたがtf.data.Datasetとして、あなたのデータにアクセスできる場合はin this tutorial

答えて

3

指定し、Estimatorに渡す前にnumpyの配列に変換する必要はありません。それは一度にメモリ内のデータセット全体をマテリアライズする必要がなくなりので、これははるかに効率的numpyの配列を構築するよりもする必要があります

def train_input_fn(): 
    dataset = tf.data.TFRecordDataset(filenames=filenames) 
    dataset = dataset.map(_read_one_image) 
    dataset = dataset.repeat(1) # Because `num_epochs=1`. 
    dataset = dataset.batch(100) # Because `batch_size=1`. 

    dataset = dataset.prefetch(1) # To improve performance by overlapping execution. 

    iterator = dataset.make_one_shot_iterator() # NOTE: Use a "one-shot" iterator. 
    labels_tensor, images_tensor = iterator.get_next() 

    return {"x": images_tensor}, labels_tensor 

dataset_classifier.train(
    input_fn=train_input_fn, num_epochs=1, hooks=[logging_hook]) 

:あなたは、単に以下のようなもので、あなたの入力機能に直接Datasetを構築することができます。 Dataset.prefetch()のようなパフォーマンス拡張とDataset.map()のパラレルバージョンを使用してトレーニングの速度を向上させることもできます。

関連する問題