2017-05-15 8 views
0

データをネットワークに適切に供給するためのキューの使い方について、いろいろと調べました。しかし、インターネット上では解決策を見つけることができませんでした。（質問タイトル：TensorFlowでキューを使ってネットワークにデータを供給する際の、検証用とトレーニング用のグラフ）

現在のコードはトレーニングデータを読み込んでトレーニングを行うことはできますが、検証とテストは行っていません。以下がコードの重要な部分です:

images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)

print("Initiliaze training")
logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)

train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

# Summaries are grouped into per-phase collections ('train' / 'test').
summary_train = tf.summary.merge_all('train')
summary_test = tf.summary.merge_all('test')

# Local variables must be initialised too (e.g. the epoch counter that
# string_input_producer keeps when num_epochs is set).
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

saver = tf.train.Saver(max_to_keep=2)
with tf.Session() as sess:

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
    summary_writer_test = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run_test, sess.graph)
    sess.run(init)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    print("Start training")

    # Initialised before `try` so the `finally` report can always read it.
    step = 0
    try:
        while not coord.should_stop():
            start_time = time.time()

            # Fetch the summary in the SAME run as the train op. A separate
            # sess.run of a summary that depends on the input batch would
            # dequeue (and skip) another batch, and describe data that was
            # never trained on.
            show = step % FLAGS.show_step == 0
            fetches = [train_optimizer, loss_intermediate, loss]
            if show:
                fetches.append(summary_train)
            results = sess.run(fetches)
            duration = time.time() - start_time
            loss_intermediate_value, loss_value = results[1], results[2]

            if show:
                print('Step %d: loss_intermediate = %.2f, loss = %.5f (%.3f sec)' % (step, loss_intermediate_value, loss_value, duration))
                summary_writer.add_summary(results[3], step)
                summary_writer.flush()

            if step % FLAGS.test_interval == 0:
                # TODO: run validation here. Build a second input pipeline and
                # inference ops for the validation file, then sess.run them.
                pass

            step += 1
    except tf.errors.OutOfRangeError:
        # Raised by the input queue once num_epochs epochs have been read:
        # this is the normal end of training, not a bug.
        print('Input queue exhausted: epoch limit reached.')
    finally:
        # %s rather than %d: FLAGS.num_epochs may be None ("no limit").
        print('Done training for %s epochs, %d steps.' % (FLAGS.num_epochs, step))
        # When done, ask the threads to stop.
        coord.request_stop()
        # Wait for threads to finish.
        coord.join(threads)
        summary_writer.close()
        summary_writer_test.close()

データの読み込みには次の関数を使っています。

def inputs(train, batch_size, num_epochs=None):
    """Build a shuffled, batched queue-based input pipeline over one file.

    Args:
        train: Path of the record file to read. It is the only filename given
            to ``tf.train.string_input_producer``.
        batch_size: Number of examples per dequeued batch.
        num_epochs: How many passes over the file before the queue raises
            ``tf.errors.OutOfRangeError``. Any falsy value (``None``, ``0``)
            means no limit. Defaults to ``None``, so e.g. a validation pipeline
            can omit it.

    Returns:
        ``(images, volumes)``: the batched tensors produced by
        ``tf.train.shuffle_batch``.
    """
    if not num_epochs:
        num_epochs = None
    filename = os.path.join(train)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)

        # read_and_decode is defined elsewhere; presumably it yields one
        # (image, volume) example per dequeue (TODO confirm).
        image, volume = read_and_decode(filename_queue)

        images, volumes = tf.train.shuffle_batch(
            [image, volume],
            batch_size=batch_size,
            num_threads=2,
            capacity=1000 * batch_size,
            min_after_dequeue=500,
        )

    # Return the *batched* labels so they line up with `images`.
    return images, volumes

TensorFlowで検証を行うために、別の入力キューや入力グラフを作成する方法がわかりません。どなたか助けていただけませんか？どんな助けでもありがたいです！（回答より：データがすでにトレーニング用と検証用のデータセットに分割されていれば、あとは検証データ用に別の入力パイプラインを作成するだけです。）

EDIT

def _conv(self, inputs, nb_filter, kernel_size=1, strides=1, pad='VALID', name='conv'):
    """Apply a bias-free, activation-free 2-D convolution.

    Every call creates a brand-new ``tf.Variable`` named ``weights`` inside a
    name scope called ``name``. Its initial value comes from a non-uniform
    Xavier initializer. Nothing is shared between calls.

    Args:
        inputs: Input tensor. Dimension 3 of its static shape is used as the
            input-channel count (the conv uses ``data_format='NHWC'``).
        nb_filter: Number of output channels.
        kernel_size: Height and width of the square kernel.
        strides: Stride along the height and width axes.
        pad: Padding mode passed to ``tf.nn.conv2d``.
        name: Name scope for the created ops.

    Returns:
        The output tensor of ``tf.nn.conv2d``.
    """
    with tf.name_scope(name):
        in_channels = int(inputs.get_shape().as_list()[3])
        kernel_shape = [kernel_size, kernel_size, in_channels, int(nb_filter)]
        xavier = tf.contrib.layers.xavier_initializer(uniform=False)
        weights = tf.Variable(xavier(kernel_shape), name='weights')
        return tf.nn.conv2d(inputs, weights, [1, strides, strides, 1],
                            padding=pad, data_format='NHWC')

EDIT 2

with tf.Graph().as_default():
    print("Load Data...")
    images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
    # Validation pipeline without num_epochs. Presumably utils.inputs treats
    # that as "cycle forever", so it never raises OutOfRangeError (TODO confirm).
    v_images, v_volumes = utils.inputs(FLAGS.val_file_path, FLAGS.batch_size)

    print("input shape: " + str(images.get_shape()))
    print("output shape: " + str(volumes.get_shape()))

    print("Initialize training")
    logits = utils.inference(images, FLAGS.stacks, True)
    # NOTE(review): this builds a second copy of the network. If
    # utils.inference creates its weights with tf.Variable (as the _conv helper
    # shown in this post does), v_logits gets its own, separately initialised
    # weights and never sees the trained ones. Share them via
    # tf.variable_scope(..., reuse=True) + tf.get_variable (TODO confirm).
    v_logits = utils.inference(v_images, FLAGS.stacks, False)

    tf.add_to_collection("logits", v_logits)

    loss = utils.get_loss(logits, volumes, FLAGS.stacks, 'train')
    v_loss = utils.get_loss(v_logits, v_volumes, FLAGS.stacks, 'val')

    # Force ops registered in UPDATE_OPS (typically batch-norm statistic
    # updates) to run with every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_optimizer = utils.pre_training(loss, FLAGS.learning_rate)

    validate = utils.validate(v_images, v_logits, v_volumes, FLAGS.scale)

    summary_train_op = tf.summary.merge_all('train')
    summary_val_op = tf.summary.merge_all('val')

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    saver = tf.train.Saver(max_to_keep=2)
    # One checkpoint prefix for both periodic and final saves.
    checkpoint_file = os.path.join(FLAGS.train_dir + FLAGS.run, 'hourglass-model')
    with tf.Session() as sess:

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run, sess.graph)
        summary_writer_val = tf.summary.FileWriter(FLAGS.train_dir + FLAGS.run + FLAGS.run_val, sess.graph)
        sess.run(init)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Initialised before `try` so `finally` can always report and save it.
        step = 0
        try:
            print("Start training")
            while not coord.should_stop():

                start_time = time.time()
                # Every run touching the input tensors dequeues a new batch, so
                # loss and summary are fetched together with the train op.
                _, loss_list, summary_str = sess.run([train_optimizer, loss, summary_train_op])
                duration = time.time() - start_time

                if (step + 1) % FLAGS.show_step == 0:
                    print('Step %d: (%.3f sec)' % (step, duration), end=': ')
                    print(", ".join('%.5f' % float(x) for x in loss_list))
                    summary_writer.add_summary(summary_str, step)

                if (step + 1) % FLAGS.val_interval == 0:
                    val_loss_sum_list = [0] * len(v_loss)
                    summary_str_val = None
                    for _ in range(FLAGS.val_iter):
                        # Each run dequeues one fresh validation batch.
                        _, val_loss_list, summary_str_val = sess.run([validate, v_loss, summary_val_op])
                        val_loss_sum_list = [a + b for a, b in zip(val_loss_sum_list, val_loss_list)]

                    if FLAGS.val_iter > 0:
                        print('Validation Interval %d: ' % (step // FLAGS.val_interval), end='')
                        print(", ".join('%.5f' % float(x / FLAGS.val_iter) for x in val_loss_sum_list))
                        # Only the summary of the last validation batch is written.
                        summary_writer_val.add_summary(summary_str_val, step)

                if (step + 1) % FLAGS.step_save_checkpoint == 0:
                    saver.save(sess, checkpoint_file, global_step=step)
                    print('Step: ' + str(step))
                    print('Saved: ' + checkpoint_file)

                step += 1
        except tf.errors.OutOfRangeError:
            # The training queue ran out after num_epochs epochs: normal end.
            print('OUT OF RANGE ERROR')
        except Exception:
            # Best effort: report the full traceback, then still save and shut
            # down cleanly in `finally`.
            import traceback
            print('Unexpected error in code')
            traceback.print_exc(file=sys.stdout)
        finally:
            # %s rather than %d: FLAGS.num_epochs may be None, and a TypeError
            # here would skip the checkpoint save and thread shutdown below.
            print('Done training for %s epochs, %d steps.' % (FLAGS.num_epochs, step))
            saver.save(sess, checkpoint_file, global_step=step)
            print('Step: ' + str(step))
            print('Saved: ' + checkpoint_file)

            # When done, ask the threads to stop.
            coord.request_stop()
            # Wait for threads to finish.
            coord.join(threads)
            summary_writer.close()
            summary_writer_val.close()
+0

似たような目的で同じことをしたことがあります（画像ではなくCSVファイルからの読み込みですが）。トレーニングループが終わった後、キューランナーがテスト/検証ファイルを指すように変更して再起動するだけです。うまく動作しているようです。 –

+0

でも、私はトレーニングの途中で検証を行いたいのです! @VS_FF – thigi

+0

ああ、なるほど。はい、この方法だとトレーニングの後でしか検証できません –

答えて

1

提供されたコードを使って検証データを扱うと、次のようになります:

 
images, volumes = utils.inputs(FLAGS.train_file_path, FLAGS.batch_size, FLAGS.num_epochs)
# create validation pipeline
v_images, v_volumes = utils.inputs(FLAGS.valid_file_path, FLAGS.batch_size, None)

logits = utils.inference(images)
loss_intermediate, loss = utils.get_loss(logits, volumes)
# define validation ops
# NOTE: the validation score only reflects training if v_logits reuses the
# training weights (tf.variable_scope(..., reuse=True) + tf.get_variable).
# If utils.inference creates weights with tf.Variable, this second call gets
# its own, untrained weights.
v_logits = utils.inference(v_images)
accuracy = utils.accuracy(v_logits, v_volumes)

# ... a bunch of code here ...

with tf.Session() as sess:
    # ... more code here ...
    if step % FLAGS.test_interval == 0:
        # Fetch the tensor itself, not a one-element list, so `acc` is the
        # value rather than [value].
        acc = sess.run(accuracy)
        print('Accuracy on validation data: {}'.format(acc))
    # ... more code here ...

これはあなたが探していたものですか?

関連する問題