2017-02-02 11 views
1

私はKerasを習得しようとしており、分類問題にLSTMを使用しています。 の精度と損失をプロットし、トレーニング中にプロットを更新したいと考えています。そのために私はcallback functionを使用しています。 何らかの理由でコールバックで受け取った精度と損失が、 と一致しないため、フィット関数によって出力された精度と損失があります。プログラムを実行する場合、私は(各バッチの最初の行は、コールバックで印刷され、第二のmodel.fitで印刷される)、この出力を取得Kerasのコールバックとプログレスバーの精度が一致しません

class PlotCallbacks(Callback): 
    def on_batch_end(self, batch, logs={}): 
     print(logs) 
     return 

# Create the model 
model = Sequential() 
model.add(Embedding(top_words, embedding_vector_length,input_length=max_conv_length)) 
model.add(LSTM(300)) 
model.add(Dense(3, activation='softmax')) 
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 
model.fit(x_train, y_train, callbacks=[PlotCallbacks()], nb_epoch=1, batch_size=3, verbose=1) 

ここ

は、自分のコードの関連するラインであります

Epoch 1/1 
{'acc': 0.0, 'loss': 1.1038421, 'batch': 0, 'size': 3} 
3/25 [==>...........................] - ETA: 27s - loss: 1.1038 - acc: 0.0000e+00 

{'acc': 1.0, 'loss': 1.0622898, 'batch': 1, 'size': 3} 
6/25 [======>.......................] - ETA: 19s - loss: 1.0831 - acc: 0.5000  

{'acc': 1.0, 'loss': 0.91526389, 'batch': 2, 'size': 3} 
9/25 [=========>....................] - ETA: 13s - loss: 1.0271 - acc: 0.6667 

{'acc': 1.0, 'loss': 0.36570337, 'batch': 3, 'size': 3} 
12/25 [=============>................] - ETA: 11s - loss: 0.8618 - acc: 0.7500 

{'acc': 1.0, 'loss': 0.1433304, 'batch': 4, 'size': 3} 
15/25 [=================>............] - ETA: 9s - loss: 0.7181 - acc: 0.8000 

{'acc': 1.0, 'loss': 0.041385528, 'batch': 5, 'size': 3} 
18/25 [====================>.........] - ETA: 6s - loss: 0.6053 - acc: 0.8333 

{'acc': 1.0, 'loss': 0.011424608, 'batch': 6, 'size': 3} 
21/25 [========================>.....] - ETA: 3s - loss: 0.5205 - acc: 0.8571 

{'acc': 1.0, 'loss': 0.0034991663, 'batch': 7, 'size': 3} 
24/25 [===========================>..] - ETA: 1s - loss: 0.4558 - acc: 0.8750 

{'acc': 1.0, 'loss': 0.0, 'batch': 8, 'size': 1} 
25/25 [==============================] - 26s - loss: 0.4377 - acc: 0.8800  

私はlogs.get( 'acc')を印刷し、PlotCallbacksオブジェクトのリストに精度を保存してリストを印刷しようとしましたが、問題は残ります。

誰かが問題の原因を手掛かりにしていますか?

おかげ

答えて

1

on_batch_end()タイプのコールバック関数だけで訓練を受けてしまったバッチの精度を取得します。ケラスで印刷されたログは、現在のエポックで見たバッチ全体の平均です。あなたは簡単にあなたのログでそれを観察することができます..最初の2バッチでは、0.01.0の精度があり、これは2バッチ以上の全体精度を0.5000と見なしました。 hereは、平均が計算される場所です。

メトリックとしての精度は、通常、エポックからエポックに報告されるため、コールバックをon_epoch_end()に変更することができます。

+0

非常に良い説明!ありがとうございました – lundeqvist

関連する問題