2017-06-15 13 views
0

現在、深いconvNetを実装することで、kaggleのcats vs dogs分類タスクに取り組んでいます。次のコード行は、データの前処理のために使用されます。Pythonでカスタムデータのnext_batch()関数を実装する方法

def label_img(img): 
    word_label = img.split('.')[-3] 
    if word_label == 'cat': return [1,0] 
    elif word_label == 'dog': return [0,1] 

def create_train_data(): 
    training_data = [] 
    for img in tqdm(os.listdir(TRAIN_DIR)): 
     label = label_img(img) 
     path = os.path.join(TRAIN_DIR,img) 
     img = cv2.resize(cv2.imread(path,cv2.IMREAD_GRAYSCALE),IMG_SIZE,IMG_SIZE)) 
     training_data.append([np.array(img),np.array(label)]) 

    shuffle(training_data) 
    return training_data 

train_data = create_train_data() 

X_train = np.array([i[0] for i in train_data]).reshape(-1, IMG_SIZE,IMG_SIZE,1) 
Y_train =np.asarray([i[1] for i in train_data]) 

は私がtensorflow深いMNISTチュートリアル

batch = mnist.train.next_batch(100) 

答えて

0

このcodeで提供される次の機能を複製する機能を実装したいの良い例です。バッチを生成する関数を考え出してください。

簡単に説明すると、あなただけのx_trainための2つの配列を考え出す必要があり、好きy_train:よう

batch_inputs = np.ndarray(shape=(batch_size), dtype=np.int32) 
    batch_labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 

とSET電車データ:

batch_inpouts[i] = ... 
    batch_labels[i, 0] = ... 

は最後のセッションにデータセットを渡します:

_, loss_val = session.run([optimizer, loss], feed_dict={train_inputs: batch_inputs, train_labels:batch_labels}) 
+0

これを試してみます。御時間ありがとうございます。 –

2

ランダムにデータを並べ替えることもできます各バッチ。

EPOCH = 100 
BATCH_SIZE = 128 
TRAIN_DATASIZE,_,_,_ = X_train.shape 
PERIOD = TRAIN_DATASIZE/BATCH_SIZE #Number of iterations for each epoch 

for e in range(EPOCH): 
    idxs = numpy.random.permutation(TRAIN_DATASIZE) #shuffled ordering 
    X_random = X_train[idxs] 
    Y_random = Y_train[idxs] 
    for i in range(PERIOD): 
     batch_X = X_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     batch_Y = Y_random[i * BATCH_SIZE:(i+1) * BATCH_SIZE] 
     sess.run(train,feed_dict = {X: batch_X, Y:batch_Y}) 
+0

ありがとうございます。最後に、ネットワークを正しく鍛えることができます。 –

+0

テンソルフローのnext_batch()が返すものについて私に教えてもらえますか?指定されたバッチサイズのトレーニングセットからのデータのランダムなコレクションですか?もしそうなら、それは非繰り返しを確実にしますか? @Joshua Lim –

+0

next_batch()は、tensorflowによって提供されるMNISTチュートリアル専用の関数です。どのように機能するかは、最初にトレーニング画像とラベルのペアをランダム化し、関数が呼び出されるたびに後続の100個の各画像を選択することです。終了に達すると、画像ラベル対が再びランダム化され、処理が繰り返される。使用可能なすべてのペアが使用されると、データセット全体が再シャッフルされ、繰り返されます。 –

関連する問題