2017-09-26 16 views
8
def train(): 
# Model 
model = Model() 

# Loss, Optimizer 
global_step = tf.Variable(1, dtype=tf.int32, trainable=False, name='global_step') 
loss_fn = model.loss() 
optimizer = tf.train.AdamOptimizer(learning_rate=TrainConfig.LR).minimize(loss_fn, global_step=global_step) 

# Summaries 
summary_op = summaries(model, loss_fn) 

with tf.Session(config=TrainConfig.session_conf) as sess: 

    # Initialized, Load state 
    sess.run(tf.global_variables_initializer()) 
    model.load_state(sess, TrainConfig.CKPT_PATH) 

    writer = tf.summary.FileWriter(TrainConfig.GRAPH_PATH, sess.graph) 

    # Input source 
    data = Data(TrainConfig.DATA_PATH) 

    loss = Diff() 
    for step in xrange(global_step.eval(), TrainConfig.FINAL_STEP): 

      mixed_wav, src1_wav, src2_wav, _ = data.next_wavs(TrainConfig.SECONDS, TrainConfig.NUM_WAVFILE, step) 

      mixed_spec = to_spectrogram(mixed_wav) 
      mixed_mag = get_magnitude(mixed_spec) 

      src1_spec, src2_spec = to_spectrogram(src1_wav), to_spectrogram(src2_wav) 
      src1_mag, src2_mag = get_magnitude(src1_spec), get_magnitude(src2_spec) 

      src1_batch, _ = model.spec_to_batch(src1_mag) 
      src2_batch, _ = model.spec_to_batch(src2_mag) 
      mixed_batch, _ = model.spec_to_batch(mixed_mag) 

      # Initializae our callback. 
      #early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.5) 


      l, _, summary = sess.run([loss_fn, optimizer, summary_op], 
            feed_dict={model.x_mixed: mixed_batch, model.y_src1: src1_batch, 
               model.y_src2: src2_batch}) 

      loss.update(l) 
      print('step-{}\td_loss={:2.2f}\tloss={}'.format(step, loss.diff * 100, loss.value)) 

      writer.add_summary(summary, global_step=step) 

      # Save state 
      if step % TrainConfig.CKPT_STEP == 0: 
       tf.train.Saver().save(sess, TrainConfig.CKPT_PATH + '/checkpoint', global_step=step) 

    writer.close() 

私は、.wavファイルの音声から音楽を分離するこの神経ネットワークコードを持っています。 列車区間を停止するための早期停止アルゴリズムを導入するにはどうすればよいですか?私は、ValidationMonitorに関するいくつかのプロジェクトを見ています。誰か助けてくれますか?テンソルフローで早期停止を実装する方法

答えて

0

ValidationMonitorは推奨されなくなりました。お勧めしません。それでもあなたはそれを使うことができます。

validation_monitor = monitors.ValidationMonitor(
     input_fn=functools.partial(input_fn, subset="evaluation"), 
     eval_steps=128, 
     every_n_steps=88, 
     early_stopping_metric="accuracy", 
     early_stopping_rounds = 1000 
    ) 

、あなたはここで、自分で自分の私の実装を実現することができます:ここ は1を作成する方法の一例である

  if (loss_value < self.best_loss): 
      self.stopping_step = 0 
      self.best_loss = loss_value 
      else: 
      self.stopping_step += 1 
      if self.stopping_step >= FLAGS.early_stopping_step: 
      self.should_stop = True 
      print("Early stopping is trigger at step: {} loss:{}".format(global_step,loss_value)) 
      run_context.request_stop() 
関連する問題