2

私はTF Slim入門書を見ています。私が理解しているところでは、各回の画像データの1バッチ(32画像)しか取り込まれません。明らかに、これをループして、多くの異なるバッチを訓練したいと考えています。イントロはこれをカバーしていません。どのようにこれを適切に行うことができますか。バッチトレーニングイベントを開始するときに自動的に呼び出されるロードバッチ関数を指定する方法があるはずですが、イントロで簡単な例を見つけることはできません。Tensorflow Slimでのバッチトレーニング

# Note that this may take several minutes. 

import os 

from datasets import flowers 
from nets import inception 
from preprocessing import inception_preprocessing 

slim = tf.contrib.slim 
image_size = inception.inception_v1.default_image_size 


def get_init_fn(): 
    """Returns a function run by the chief worker to warm-start the training.""" 
    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] 

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes] 

    variables_to_restore = [] 
    for var in slim.get_model_variables(): 
     excluded = False 
     for exclusion in exclusions: 
      if var.op.name.startswith(exclusion): 
       excluded = True 
       break 
     if not excluded: 
      variables_to_restore.append(var) 

    return slim.assign_from_checkpoint_fn(
     os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 
     variables_to_restore) 


train_dir = '/tmp/inception_finetuned/' 

with tf.Graph().as_default(): 
    tf.logging.set_verbosity(tf.logging.INFO) 

    dataset = flowers.get_split('train', flowers_data_dir) 
    images, _, labels = load_batch(dataset, height=image_size, width=image_size) 

    # Create the model, use the default arg scope to configure the batch norm parameters. 
    with slim.arg_scope(inception.inception_v1_arg_scope()): 
     logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 

    # Specify the loss function: 
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 
    slim.losses.softmax_cross_entropy(logits, one_hot_labels) 
    total_loss = slim.losses.get_total_loss() 

    # Create some summaries to visualize the training process: 
    tf.scalar_summary('losses/Total Loss', total_loss) 

    # Specify the optimizer and create the train op: 
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 
    train_op = slim.learning.create_train_op(total_loss, optimizer) 

    # Run the training: 
    final_loss = slim.learning.train(
     train_op, 
     logdir=train_dir, 
     init_fn=get_init_fn(), 
     number_of_steps=2) 


print('Finished training. Last batch loss %f' % final_loss) 
+0

コードサンプルではload_batch関数が未定義ではありませんか?私はあなたの非常事態に慣れていませんが、私はバッチ処理を理解するためにこの関数を読み始めるでしょう。 – pltrdy

+0

ここにはhttps://github.com/tensorflow/models/blob/master/slim/slim_walkthough.ipynbがありますが、これはバッチを取得すること以外は何もしません。 –

+0

基本的にバッチ処理を繰り返す必要がありますか? – pltrdy

答えて

1

slim.learning.train機能は、トレーニングループが含まれているので、あなたが与えてくれたコードは、画像の複数のバッチで、実際の電車の中で行います。

here in the source codeを参照してください。ここで、train_step_fnはwhileループ内で呼び出されます。 train_step(デフォルト値はtrain_step_fn)には、sess.run([train_op, global_step]...)という行が含まれています。実際には、1回の画像バッチでトレーニング操作が実行されます。

+0

さて、私はload_batch関数にprintステートメントを置き、1つ以上のステップを練習し、ロードバッチ関数が1回だけ呼び出されたので、複数のステップで同じデータが使用されていることがわかりました。 –

+0

さらに、私はlearning.trainの呼び出しでload_batch関数を指定していないので、これを新しいバッチの読み込みにどのように使用するのかは分かりませんでしたか? –

+0

私はこれについてさらに研究しており、バッチが毎回自動的にロードされる場所から確立されるキューがあるようです。これをテストするために私はここに関連する質問がありますhttp://stackoverflow.com/questions/41868871/tensorflow-slim-debugging-during-training。可能であればコメントしてください。 –

関連する問題