2017-12-09 13 views
1

model.fit()メソッドを使用してケラスでモデルをトレーニングしています。 各訓練エポック後に別々に検証する必要がある複数の検証セットを使用したいので、検証セットごとに1つの損失値が得られます。可能であれば、それらはトレーニング中に表示され、keras.callbacks.History()コールバックによって返されます。ケラで複数のバリデーションセットを使用する

私はこのような何かを考えています:

history = model.fit(train_data, train_targets, 
        epochs=epochs, 
        batch_size=batch_size, 
        validation_data=[ 
         (validation_data1, validation_targets1), 
         (validation_data2, validation_targets2)], 
        shuffle=True) 

私は現在、これを実装する方法は考えています。私自身のCallbackを書いてこれを達成することは可能でしょうか?それ以外にどのようにこの問題にアプローチしますか?

答えて

2

問題を解決するためにHistoryコールバックに基づいて私自身がCallbackを作成しました。私はこれが最良のアプローチであるかどうかはわかりませんが、Callbackは、Historyコールバック、コンストラクタに渡される追加の検証セットのメトリックと同様に、トレーニングと検証のための損失とメトリックを記録します。

私は、このように使用しています
class AdditionalValidationSets(Callback): 
    def __init__(self, validation_sets, verbose=0, batch_size=None): 
     """ 
     :param validation_sets: 
     a list of 3-tuples (validation_data, validation_targets, validation_set_name) 
     or 4-tuples (validation_data, validation_targets, sample_weights, validation_set_name) 
     :param verbose: 
     verbosity mode, 1 or 0 
     :param batch_size: 
     batch size to be used when evaluating on the additional datasets 
     """ 
     super(AdditionalValidationSets, self).__init__() 
     self.validation_sets = validation_sets 
     for validation_set in self.validation_sets: 
      if len(validation_set) not in [2, 3]: 
       raise ValueError() 
     self.epoch = [] 
     self.history = {} 
     self.verbose = verbose 
     self.batch_size = batch_size 

    def on_train_begin(self, logs=None): 
     self.epoch = [] 
     self.history = {} 

    def on_epoch_end(self, epoch, logs=None): 
     logs = logs or {} 
     self.epoch.append(epoch) 

     # record the same values as History() as well 
     for k, v in logs.items(): 
      self.history.setdefault(k, []).append(v) 

     # evaluate on the additional validation sets 
     for validation_set in self.validation_sets: 
      if len(validation_set) == 3: 
       validation_data, validation_targets, validation_set_name = validation_set 
       sample_weights = None 
      elif len(validation_set) == 4: 
       validation_data, validation_targets, sample_weights, validation_set_name = validation_set 
      else: 
       raise ValueError() 

      results = self.model.evaluate(x=validation_data, 
              y=validation_targets, 
              verbose=self.verbose, 
              sample_weight=sample_weights, 
              batch_size=self.batch_size) 

      for i, result in enumerate(results): 
       if i == 0: 
        valuename = validation_set_name + '_loss' 
       else: 
        valuename = validation_set_name + '_' + self.model.metrics[i-1].__name__ 
       self.history.setdefault(valuename, []).append(result) 

history = AdditionalValidationSets([(validation_data2, validation_targets2, 'val2')]) 
model.fit(train_data, train_targets, 
      epochs=epochs, 
      batch_size=batch_size, 
      validation_data=(validation_data1, validation_targets1), 
      callbacks=[history] 
      shuffle=True) 
関連する問題