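One possible answer is `make_template`, which lets the training and test pipelines share the same variables. It is outlined in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/template_test.py; the code below is adapted from that test: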
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API

training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
test_input, test_output = ([5., 6., 7., 8.], [11., 13., 15., 17.])

tf.set_random_seed(1234)

def test_line(x):
    # Inside a template, tf.get_variable creates "w" and "b" on the first
    # call and reuses them on every later call.
    m = tf.get_variable("w", shape=[],
                        initializer=tf.truncated_normal_initializer())
    b = tf.get_variable("b", shape=[],
                        initializer=tf.truncated_normal_initializer())
    return x * m + b

# Wrap test_line so that both calls below share a single set of variables.
line_template = tf.make_template("line", test_line)

train_prediction = line_template(training_input)
test_prediction = line_template(test_input)

train_loss = tf.reduce_mean(tf.square(train_prediction - training_output))
test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))

optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(train_loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    initial_test_loss = sess.run(test_loss)
    sess.run(train_op)
    final_test_loss = sess.run(test_loss)

# Parameters are tied, so the test loss should go down when we train.
assert final_test_loss < initial_test_loss
```
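To make the sharing behaviour concrete without TensorFlow, here is a minimal pure-Python sketch of the idea behind `make_template`: the first call creates the parameters, and later calls reuse them. `make_template_sketch`, `get_variable` and `line` are hypothetical names for illustration only; this is not how TensorFlow implements templates.

```python
import random

def make_template_sketch(fn):
    """Hypothetical stand-in for tf.make_template (not TF's implementation):
    parameters requested by fn are created on the first call and reused later."""
    store = {}  # name -> value, shared by every call of the returned template

    def get_variable(name, initializer):
        # Create the parameter only the first time this name is requested.
        if name not in store:
            store[name] = initializer()
        return store[name]

    def template(*args):
        return fn(get_variable, *args)

    template.variables = store
    return template

def line(get_variable, xs):
    # Same model as test_line: y = x * w + b.
    w = get_variable("w", random.random)
    b = get_variable("b", random.random)
    return [x * w + b for x in xs]

random.seed(1234)
line_template = make_template_sketch(line)
train_prediction = line_template([1., 2., 3., 4.])
test_prediction = line_template([5., 6., 7., 8.])
print(sorted(line_template.variables))  # ['b', 'w']: both calls shared one w and one b
```

Because the store lives in the closure, any number of calls (training, test, validation) see the same two parameters.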
[`tf.cond` is not a solution](https://groups.google.com/a/tensorflow.org/forum/#!msg/discuss/mLrt5qc9_uU/sGNbC7GpAwAJ): evaluating either branch pulls items through both queues. – mdaoust
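One way this happens, assuming the dequeue ops are created outside the branch functions passed to `tf.cond`, is that both dequeues are already part of what gets run, whichever branch is selected. The plain-Python analogy below (not TensorFlow code; `CountingQueue`, `cond_eager` and `cond_lazy` are hypothetical) contrasts selecting between values that were already dequeued with selecting between functions that dequeue only when called.

```python
from collections import deque

class CountingQueue:
    """Hypothetical queue that records how many items were dequeued."""
    def __init__(self, items):
        self._items = deque(items)
        self.dequeued = 0

    def dequeue(self):
        self.dequeued += 1
        return self._items.popleft()

def cond_eager(pred, true_value, false_value):
    # Both arguments were computed by the caller before selection,
    # like tensors built outside the cond branches.
    return true_value if pred else false_value

def cond_lazy(pred, true_fn, false_fn):
    # Only the selected branch function runs.
    return true_fn() if pred else false_fn()

train_q, test_q = CountingQueue([1, 2, 3]), CountingQueue([10, 20, 30])
eager_batch = cond_eager(True, train_q.dequeue(), test_q.dequeue())
print(eager_batch, train_q.dequeued, test_q.dequeued)  # 1 1 1: the test queue was drained too

train_q2, test_q2 = CountingQueue([1, 2, 3]), CountingQueue([10, 20, 30])
lazy_batch = cond_lazy(True, train_q2.dequeue, test_q2.dequeue)
print(lazy_batch, train_q2.dequeued, test_q2.dequeued)  # 1 1 0
```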
You can still use `tf.cond`, but it's a hack. Instead, there is `QueueBase.from_list`. See: https://github.com/tensorflow/tensorflow/issues/2514 - sigh.. – TimZaman
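As far as I know, the TF 1.x signature is `QueueBase.from_list(index, queues)`, which returns a queue that uses the reference of `queues[index]`. The queues must agree on dtypes (and shapes/names), and `index` can be a scalar tensor fed at run time. The pure-Python sketch below only mimics that routing idea with the standard `queue` module. `SelectedQueue` and the `mode` dict standing in for a fed index are hypothetical.

```python
import queue

class SelectedQueue:
    """Hypothetical analogue of QueueBase.from_list(index, queues): one handle
    whose dequeue() is routed to queues[index_fn()], looked up at call time."""
    def __init__(self, index_fn, queues):
        self._index_fn = index_fn
        self._queues = queues

    def dequeue(self):
        # Only the currently selected queue is touched.
        return self._queues[self._index_fn()].get_nowait()

train_q, test_q = queue.Queue(), queue.Queue()
for v in (1, 2, 3):
    train_q.put(v)
for v in (10, 20, 30):
    test_q.put(v)

mode = {"index": 0}  # 0 = training queue, 1 = test queue (stands in for a fed tensor)
q = SelectedQueue(lambda: mode["index"], [train_q, test_q])
first = q.dequeue()    # 1, from the training queue
mode["index"] = 1
second = q.dequeue()   # 10, from the test queue
print(first, second, train_q.qsize(), test_q.qsize())  # 1 10 2 2
```

Unlike the eager `tf.cond` pattern, each dequeue consumes an item from exactly one queue.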