問題を提起するのが間違っている場合はお詫び申し上げます。私はKerasとPythonの初心者ですので、応答にはそのことを念頭に置いてください。Keras fit_generatorでCNNをバッチトレインする方法は?
私は入力として画像を取るCNNステアリングモデルを訓練しようとしています。これはかなり大きなデータセットなので、私はfit_generator()で動作するデータジェネレータを作成しました。このメソッドをバッチでトレーニングする方法はわかりませんので、ジェネレータがfit_generator()にバッチを返さなければならないと仮定しました。発電機は次のようになります
def gen(file_name, batchsz = 64):
csvfile = open(file_name)
reader = csv.reader(csvfile)
batchCount = 0
while True:
for line in reader:
inputs = []
targets = []
temp_image = cv2.imread(line[1]) # line[1] is path to image
measurement = line[3] # steering angle
inputs.append(temp_image)
targets.append(measurement)
batchCount += 1
if batchCount >= batchsz:
batchCount = 0
X = np.array(inputs)
y = np.array(targets)
yield X, y
csvfile.seek(0)
はテレメトリデータを含むCSVファイル(操舵角など)を読み取り、パス画像サンプルに、サイズの配列を返す:batchsz (fit_generatorする呼び出し)は次のようになります。
tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator
try:
model.fit_generator(
tgen,
samples_per_epoch=113526,
nb_epoch=6,
validation_data=vgen,
nb_val_samples=20001
)
データセットは、113526個のサンプルポイントを含む、まだモデル訓練更新出力は、(例えば)このように読める:
1020/113526 [..............................] - ETA: 27737s - loss: 0.0080
1021/113526 [..............................] - ETA: 27723s - loss: 0.0080
1022/113526 [..............................] - ETA: 27709s - loss: 0.0080
1023/113526 [..............................] - ETA: 27696s - loss: 0.0080
どのサンプルごとにトレーニングをしているようです(確率的に?)。 結果のモデルは役に立たない。以前は.fit()を使用して、メモリにロードされたデータセット全体を使って、はるかに小さなデータセットを訓練しました。明らかに、私のfit_generator()アプローチで何かが間違っています。これについての助けには非常に感謝します。
'samples_per_epoch'は、[keras documentation](https://keras.io/models/sequence/)で提案されているように、' total_samples/batch_size'と等しくなければなりません。 'samples_per_epoch'は、エポックが完了したとみなされる前にジェネレータが呼び出された回数を指定します。使用している' batch_size 'が何であるか分かりません。 – gionni
ありがとう@gionni。 Keras 1.0.2から最新に更新されました。 fit-generator()パラメータは、このバージョンではより意味があります。 – tinyMind