2016-05-02 13 views
5

テンソルフローのAPIに既にバッチ正規化関数があることに気付きました。私が理解していないことの1つは、トレーニングとテストの間の手順をどのように変更するかです。テンソルフローのバッチ正規化

バッチの正規化は、トレーニング中とテスト中に異なる動作をします。具体的には、トレーニング中に固定平均と分散を使用します。

どこか良いサンプルコードがありますか?私はいくつかを見ましたが、スコープ変数が混乱しています

+0

は、 'tf.contribのような高レベルAPIから事前に定義された層を使用することを検討してください.layers'。 – danijar

答えて

9

そうです、tf.nn.batch_normalizationは、バッチ正規化を実装するための基本的な機能を提供します。トレーニング中に移動する手段と分散を追跡する余分なロジックを追加し、推論の際に訓練された平均と分散を使用する必要があります。あなたは非常に一般的な実装のために、このexampleで見ることができますが、gammaを使用しない迅速なバージョンはここにある:

beta = tf.Variable(tf.zeros(shape), name='beta') 
    moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean', 
           trainable=False) 
    moving_variance = tf.Variable(tf.ones(shape), 
            name='moving_variance', 
            trainable=False) 
    control_inputs = [] 
    if is_training: 
    mean, variance = tf.nn.moments(image, [0, 1, 2]) 
    update_moving_mean = moving_averages.assign_moving_average(
     moving_mean, mean, self.decay) 
    update_moving_variance = moving_averages.assign_moving_average(
     moving_variance, variance, self.decay) 
    control_inputs = [update_moving_mean, update_moving_variance] 
    else: 
    mean = moving_mean 
    variance = moving_variance 
    with tf.control_dependencies(control_inputs): 
    return tf.nn.batch_normalization(
     image, mean=mean, variance=variance, offset=beta, 
     scale=None, variance_epsilon=0.001) 
+0

ありがとうございます。もう一つの簡単な質問。ガンマバージョンは本当に複雑ですか?あなたはそれのために別のtf.Variableを初期化しなければならないだろうか?それ以外のコードは同じでなければなりません。 – user3358117

+0

はい、「ガンマ」を追加するために提供したリンクのより一般的な実装に従うことができます。 – keveman