バッチ正規化操作の少し微調整されたバージョンを実装しようとしています。私は平均と分散のような移動平均値を明示的に保つ必要があります。これを行うために、私はTensorflowでの割り当てと制御の依存メカニズムを使っていくつかの実験をしています。私は不思議な問題にぶつかります。私は次のおもちゃのコードを持っています。ここで私が意図したとおりtf.control_dependencies
仕事かどうかをテストしようとしています:Tensorflow:tf.assignは何も割り当てません。
dataset = MnistDataSet(validation_sample_count=10000,
load_validation_from="validation_indices")
samples, labels, indices_list, one_hot_labels =
dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE)
samples = np.expand_dims(samples, axis=3)
flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR)
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32),
trainable=False, dtype=tf.float32)
a = tf.Variable(name="a", initial_value=5.0, trainable=False)
b = tf.Variable(name="b", initial_value=4.0, trainable=False)
c = tf.Variable(name="c", initial_value=0.0, trainable=False)
batch_mean, batch_var = tf.nn.moments(flat_data, [0])
b_op = tf.assign(b, a)
mean_op = tf.assign(mean, batch_mean)
with tf.control_dependencies([b_op, mean_op]):
c = a + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples})
私は単に784件の寸法を有する各エントリにデータのバッチをロードし、その瞬間を計算し、変数にbatch_mean
を保存しようとしていますmean
。私は自明に変数a
の値をb
に保存します。
最後の行で、c
とmean
の値のグラフを実行すると、期待値であるc
と表示されます。しかし、mean
は依然として100のベクトルであり、バッチ平均を含まない。それは実行されていないmean_op = tf.assign(mean, batch_mean)
のようなものです。
この理由は何ですか?私が知る限り、tf.control_dependencies
コールのすべての操作は、以下のコンテキストでの操作の前に実行する必要があります。文脈にあるc
をここに明示的に呼び出します。何か不足していますか?
私は完全に理解していませんでした。私の理解では、sess.run()はすべての操作とその依存関係をグラフで実行する必要があります。私は明示的にcと意味をsess.run()で述べます。なぜ変数 "平均"の更新をスキップするのですか? –
両方のopsが評価されますが、セッションが 'mean' **の前に**' c'を評価すると、古い値が返されます。再度評価することによって 'mean'が更新されたことを確認することができます。 – Maxim
ああ、グラフは、正しい値が割り当てられる前にフェッチされます。それでは同じことが 'c'に起こる可能性がありますよね? –