2017-11-09 2 views
2

バッチ正規化操作の少し微調整されたバージョンを実装しようとしています。私は平均と分散のような移動平均値を明示的に保つ必要があります。これを行うために、私はTensorflowでの割り当てと制御の依存メカニズムを使っていくつかの実験をしています。私は不思議な問題にぶつかります。私は次のおもちゃのコードを持っています。ここで私が意図したとおりtf.control_dependencies仕事かどうかをテストしようとしています:Tensorflow:tf.assignは何も割り当てません。

dataset = MnistDataSet(validation_sample_count=10000, 
load_validation_from="validation_indices") 
samples, labels, indices_list, one_hot_labels = 
dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE) 
samples = np.expand_dims(samples, axis=3) 

flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR) 
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32), 
       trainable=False, dtype=tf.float32) 
a = tf.Variable(name="a", initial_value=5.0, trainable=False) 
b = tf.Variable(name="b", initial_value=4.0, trainable=False) 
c = tf.Variable(name="c", initial_value=0.0, trainable=False) 
batch_mean, batch_var = tf.nn.moments(flat_data, [0]) 

b_op = tf.assign(b, a) 
mean_op = tf.assign(mean, batch_mean) 
with tf.control_dependencies([b_op, mean_op]): 
    c = a + b 

init = tf.global_variables_initializer() 
sess = tf.Session() 
sess.run(init) 

results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples}) 

私は単に784件の寸法を有する各エントリにデータのバッチをロードし、その瞬間を計算し、変数にbatch_meanを保存しようとしていますmean 。私は自明に変数aの値をbに保存します。

最後の行で、cmeanの値のグラフを実行すると、期待値であるcと表示されます。しかし、meanは依然として100のベクトルであり、バッチ平均を含まない。それは実行されていないmean_op = tf.assign(mean, batch_mean)のようなものです。

この理由は何ですか?私が知る限り、tf.control_dependenciesコールのすべての操作は、以下のコンテキストでの操作の前に実行する必要があります。文脈にあるcをここに明示的に呼び出します。何か不足していますか?

答えて

3

known "feature"tf.Session.run()です。 cmeanは独立しているので、meanc(これはmeanを更新する)の前に評価されます。

は、ここでは、この効果の短いバージョンです:

a = tf.Variable(name="a", initial_value=1.0, trainable=False) 
b = tf.Variable(name="b", initial_value=0.0, trainable=False) 
dependent_op = tf.assign(b, a * 3) 
with tf.control_dependencies([dependent_op]): 
    c = a + 1 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run([c, b])) 
    print(sess.run([b])) 

bの第二の評価は[3.0]を返すことが保証されます。しかし、最初のrun[2.0 3.0]または[2.0 0.0]のいずれかを返すことがあります。

+0

私は完全に理解していませんでした。私の理解では、sess.run()はすべての操作とその依存関係をグラフで実行する必要があります。私は明示的にcと意味をsess.run()で述べます。なぜ変数 "平均"の更新をスキップするのですか? –

+0

両方のopsが評価されますが、セッションが 'mean' **の前に**' c'を評価すると、古い値が返されます。再度評価することによって 'mean'が更新されたことを確認することができます。 – Maxim

+0

ああ、グラフは、正しい値が割り当てられる前にフェッチされます。それでは同じことが 'c'に起こる可能性がありますよね? –

関連する問題