2017-05-05 22 views
1

私はPythonとKerasを初めて使いました。私はエポックごとにウェイトファイルを保存するニューラルネットワークを構築しました。しかし、より細かく(私は時系列で層の重量分布を可視化しています)、すべての時代ではなく、N回のバッチごとに重みを保存したいと思います。Python/Keras - N個のバッチごとにモデルウェイトを保存します。

誰にも提案はありますか?

答えて

4

独自のコールバック(https://keras.io/callbacks/)を作成することができます。何かのように:私はself.batchの代わりに、各エポックで0で後で再起動するために提供さbatch引数を使用

from keras.callbacks import Callback 

class WeightsSaver(Callback): 
    def __init__(self, model, N): 
     self.model = model 
     self.N = N 
     self.batch = 0 

    def on_batch_end(self, batch, logs={}): 
     if self.batch % self.N == 0: 
      name = 'weights%08d.h5' % self.batch 
      self.model.save_weights(name) 
     self.batch += 1 

その後、あなたのフィットコールに追加します。たとえば、5バッチごとにウェイトを保存するには:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)]) 
関連する問題