2017-02-13 12 views
0

こんにちは。私はtfでかなり新しく、質問はバッチ正規化を推論に適用する方法です。私は訓練とテストの間にテンソルフローによるバッチ正規化を適用しました。コードは次のとおりです。ご覧のとおり、私はトレーニングとテストを異なる平均値とvar値で扱います。今度はトレーニングが完了し、モデルを実践に適用したいと思います。デモを使用するなど、ckptファイルをロードし、テストするだけです。この場合、どのようにしてこの1つのケースを正規化できますか?トレーニングの各エポック後にBNの平均値と平均値を保存して後で復元する方法はありますか?本当にありがとう!Tensorflowバッチ正規化を推論に適用する方法

fc_mean, fc_var = tf.nn.moments(
      input, 
      axes=[0], 
     ) 
     scale = tf.Variable(tf.ones([out_size])) 
     shift = tf.Variable(tf.zeros([out_size])) 
     epsilon = 0.001 
     ema = tf.train.ExponentialMovingAverage(decay=0.5) 
     def mean_var_with_update(): 
      ema_apply_op = ema.apply([fc_mean, fc_var]) 
      with tf.control_dependencies([ema_apply_op]): 
       return tf.identity(fc_mean), tf.identity(fc_var) 
     mean, var = tf.cond(train_phase, 
          mean_var_with_update, 
          lambda: (ema.average(fc_mean), 
             ema.average(fc_var) 
            ) 
          ) 
     input_BN = tf.nn.batch_normalization(input, mean, var, shift, scale, epsilon) 

答えて

0

これは正しいと思いますか?私は訓練中に得たシフト(オフセット)とスケールを保存しました。モデルを実際に適用する必要があるときは、平均とvarを置き換えるためにshiftとscaleを使用します。そして、テストケースのシフトとvarをゼロに設定します。

if not test: 
     input_BN = tf.nn.batch_normalization(input, mean, var, shift, scale, epsilon) 
    else: 
     input_BN = tf.nn.batch_normalization(input, shift, scale, 0, 0, epsilon) 
+0

残念ですが、今はちょっと混乱しています。この論文では、scale = sqrt(var)、shift = meanを設定することで、元のアクティベーションを回復できると述べています。だからテンソルフローによる訓練の間に、モデルがうまく収束した後に、スケールとシフトの価値はどうなるでしょうか? – syntaxError

関連する問題