2017-08-24 19 views
1

問題を提起するのが間違っている場合はお詫び申し上げます。私はKerasとPythonの初心者ですので、応答にはそのことを念頭に置いてください。Keras fit_generatorでCNNをバッチトレインする方法は?

私は入力として画像を取るCNNステアリングモデルを訓練しようとしています。これはかなり大きなデータセットなので、私はfit_generator()で動作するデータジェネレータを作成しました。このメソッドをバッチでトレーニングする方法はわかりませんので、ジェネレータがfit_generator()にバッチを返さなければならないと仮定しました。発電機は次のようになります

def gen(file_name, batchsz = 64): 
    csvfile = open(file_name) 
    reader = csv.reader(csvfile) 
    batchCount = 0 
    while True: 
     for line in reader: 
      inputs = [] 
      targets = [] 
      temp_image = cv2.imread(line[1]) # line[1] is path to image 
      measurement = line[3] # steering angle 
      inputs.append(temp_image) 
      targets.append(measurement) 
      batchCount += 1 
      if batchCount >= batchsz: 
       batchCount = 0 
       X = np.array(inputs) 
       y = np.array(targets) 
       yield X, y 
     csvfile.seek(0) 

はテレメトリデータを含むCSVファイル(操舵角など)を読み取り、パス画像サンプルに、サイズの配列を返す:batchsz (fit_generatorする呼び出し)は次のようになります。

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator 
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator 
try: 
    model.fit_generator(
     tgen, 
     samples_per_epoch=113526, 
     nb_epoch=6, 
     validation_data=vgen, 
     nb_val_samples=20001 
    ) 

データセットは、113526個のサンプルポイントを含む、まだモデル訓練更新出力は、(例えば)このように読める:

1020/113526 [..............................] - ETA: 27737s - loss: 0.0080 
    1021/113526 [..............................] - ETA: 27723s - loss: 0.0080 
    1022/113526 [..............................] - ETA: 27709s - loss: 0.0080 
    1023/113526 [..............................] - ETA: 27696s - loss: 0.0080 

どのサンプルごとにトレーニングをしているようです(確率的に?)。 結果のモデルは役に立たない。以前は.fit()を使用して、メモリにロードされたデータセット全体を使って、はるかに小さなデータセットを訓練しました。明らかに、私のfit_generator()アプローチで何かが間違っています。これについての助けには非常に感謝します。

+1

'samples_per_epoch'は、[keras documentation](https://keras.io/models/sequence/)で提案されているように、' total_samples/batch_size'と等しくなければなりません。 'samples_per_epoch'は、エポックが完了したとみなされる前にジェネレータが呼び出された回数を指定します。使用している' batch_size 'が何であるか分かりません。 – gionni

+0

ありがとう@gionni。 Keras 1.0.2から最新に更新されました。 fit-generator()パラメータは、このバージョンではより意味があります。 – tinyMind

答えて

2

この:

for line in reader: 
    inputs = [] 
    targets = [] 

...は、CSVファイル内のすべての行のバッチをリセットしています。あなたはあなたの全体のデータで訓練していないが、128

提案でただ一つのサンプルで:誰かがコメントしたよう

for line in reader: 

    if batchCount == 0: 
     inputs = [] 
     targets = [] 
    .... 
    .... 

、フィット発電機で、samples_per_epochtotal_samples/batchszに等しくなければなりません

とにかく、あなたの損失はとにかく下がるはずです。そうでない場合は、おそらくデータをロードする方法や、モデルの初期化や構造で、コードに別の問題が残っている可能性があります。

は、あなたのイメージをプロットし、発電機のデータを印刷してみてください。

for X,y in tgen: #careful, this is an infinite loop, make it stop 

    print(X.shape[0]) # is this really the number of batches you expect? 

    for image in X: 
     ...some method to plot X so you can see it, or just print  

    print(y) 

確認が得られた値は、あなたは彼らがあることを期待するものでOKであれば。

+0

"... csvファイルのすべての行についてバッチをリセットしています。"ドゥー!それを見つけたはずです。 Wierdは、配列をプリントアウトするためのテストコードをいくつか持っているので、正しいサイズとシーケンスのバッチです。 – tinyMind

+0

損失については、最近、私は "凍った"損失に問題があった。私は多くのエポックで何度も1つのサンプルを何度も練習することにしました。突然その損失が減りました。その後、徐々に他の事例を紹介し、適切な訓練を開始しました。私はモデルが複雑すぎるか、または私の体重を適切に初期化していないと思うので、興味深い結果を表示するのに時間がかかりました。 –

+0

ダニエルありがとうございます。今すぐトレーニングを受けているようです。 GPUの負荷は、GPUがスクリプトを待っているかのようにかなり低いです。 – tinyMind

関連する問題