最近、Keras用のジェネレータを試していて、ようやく動作する例を用意することができました。ランダムなデータを使っているので、NNを学習させても意味のある結果にはなりませんが、Kerasでジェネレータを使う例としては分かりやすいと思います。
データの生成
import numpy as np
import pandas as pd
# Synthetic inputs: 200 rows of two uniform random features in [0, 1).
data = np.random.rand(200, 2)
# Synthetic binary labels (0 or 1), shaped as a single column.
expected = np.random.randint(0, 2, size=200)[:, np.newaxis]

# Wrap both arrays in labelled DataFrames.
dataFrame = pd.DataFrame(data, columns=['a', 'b'])
expectedFrame = pd.DataFrame(expected, columns=['expected'])

# First 100 rows are used for training, last 100 rows for validation.
dataFrameTrain = dataFrame.iloc[:100]
dataFrameTest = dataFrame.iloc[-100:]
expectedFrameTrain = expectedFrame.iloc[:100]
expectedFrameTest = expectedFrame.iloc[-100:]
ジェネレータ
def generator(X_data, y_data, batch_size):
    """Yield ``(X_batch, y_batch)`` pairs forever, cycling through the data.

    Consecutive slices of ``batch_size`` rows are taken from ``X_data`` and
    ``y_data``. After the last (possibly shorter) batch the generator wraps
    around to the first rows again, so it can feed any number of epochs.

    Args:
        X_data: Row-sliceable features exposing ``.shape[0]`` (for example a
            NumPy array or a pandas DataFrame).
        y_data: Row-sliceable targets aligned row-for-row with ``X_data``.
        batch_size: Number of rows per batch; must be positive.

    Yields:
        A tuple ``(X_batch, y_batch)`` of ``float32`` NumPy arrays.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised when the first
            batch is requested, because generator bodies run lazily).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    samples_per_epoch = X_data.shape[0]
    # Round up so a trailing partial batch still counts as a batch.
    number_of_batches = int(np.ceil(samples_per_epoch / float(batch_size)))
    counter = 0

    while True:
        start = batch_size * counter
        stop = start + batch_size
        X_batch = np.array(X_data[start:stop]).astype('float32')
        y_batch = np.array(y_data[start:stop]).astype('float32')
        counter += 1
        yield X_batch, y_batch

        # Restart the counter once every batch has been served, so the next
        # epoch starts again from the first rows.
        if counter >= number_of_batches:
            counter = 0
Kerasモデル
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten, Reshape
from keras.layers.convolutional import Convolution1D, Convolution2D, MaxPooling2D
from keras.utils import np_utils
# Small binary classifier: one hidden ReLU layer feeding a sigmoid output unit.
model = Sequential()
model.add(Dense(12, activation='relu', input_dim=dataFrame.shape[1]))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

# Train the model using generators instead of a single in-memory fit.
batch_size = 8
# The validation generator serves batches twice as large as the training one.
validation_batch_size = batch_size * 2

# Step counts must be whole numbers and must be sized from the split each
# generator actually serves, rounded up so the last partial batch is included.
# With 100 training rows and batch_size 8 this gives 13 steps per epoch.
steps_per_epoch = int(np.ceil(dataFrameTrain.shape[0] / float(batch_size)))
validation_steps = int(np.ceil(dataFrameTest.shape[0] / float(validation_batch_size)))

# NOTE(review): fit_generator is kept to match the Keras version this example
# was written for; recent Keras releases accept generators in model.fit.
model.fit_generator(
    generator(dataFrameTrain, expectedFrameTrain, batch_size),
    epochs=3,
    steps_per_epoch=steps_per_epoch,
    validation_data=generator(dataFrameTest, expectedFrameTest, validation_batch_size),
    validation_steps=validation_steps,
)

# For comparison, training without a generator:
#model.fit(x = np.array(dataFrame), y = np.array(expected), batch_size = batch_size, epochs = 3)
出力
Epoch 1/3
25/25 [==============================] - 3s - loss: 0.7297 - acc: 0.4750 -
val_loss: 0.7183 - val_acc: 0.5000
Epoch 2/3
25/25 [==============================] - 0s - loss: 0.7213 - acc: 0.3750 -
val_loss: 0.7117 - val_acc: 0.5000
Epoch 3/3
25/25 [==============================] - 0s - loss: 0.7132 - acc: 0.3750 -
val_loss: 0.7065 - val_acc: 0.5000
これは役に立ちますか? http://anandology.com/python-practice-book/iterators.html – Rob