tf.cond()
機能は後はブール式を評価し、一つだけ真または偽の枝を実行するように設計されています。あなたはどんな副作用のOPはあなたがtf.cond()
に渡すことlambda
(または同等の機能)内で定義されていることを確認しなければならないこと、しかし
v = tf.Variable(0)
condition = tf.placeholder(tf.bool, shape=[])
op_to_run = tf.cond(condition, lambda: v.assign_add(1), lambda: v.assign_sub(1))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(v)) # ==> "0"
for _ in range(3):
sess.run(op_to_run, feed_dict={condition: True})
print(sess.run(v)) # ==> "3"
for _ in range(5):
sess.run(op_to_run, feed_dict={condition: False})
print(sess.run(v)) # ==> "-2"
注:たとえば、場合、次のようなものを書きました。副作用OPSが定義されている場合外tf.cond()
は、彼らは無条件に実行されます:明確化および注意のため
v = tf.Variable(0)
condition = tf.placeholder(tf.bool, shape=[])
# N.B. DO NOT DO THIS! Both side-effecting ops are defined outside the `tf.cond()`
# so they will both execute, regardless of the condition.
inc_op = v.assign_add(1)
dec_op = v.assign_sub(1)
op_to_run = tf.cond(condition, lambda: inc_op, lambda: dec_op)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(v)) # ==> "0"
# Both the `assign_add()` and `assign_sub()` will run, cancelling each other out.
for _ in range(3):
sess.run(op_to_run, feed_dict={condition: True})
print(sess.run(v)) # ==> "0"
# Both the `assign_add()` and `assign_sub()` will run, cancelling each other out.
for _ in range(5):
sess.run(op_to_run, feed_dict={condition: False})
print(sess.run(v)) # ==> "0"
感謝を。ラムダの中には、あるキューからサンプルをデキューするか、別のキューからデキューしたサンプルをデキューするコマンドがありました。ブール値のプレースホルダを渡して、どちらが発生するのかを確かめました。 –
@DavidParks深みを持たせたい場合は、「スイッチ/マージ」の仕組みを見ることができます。それらは 'tf.cond'によって追加されますが、これらの文を挿入することでグラフの部分を条件付きで実行させることもできます。https://gist.github.com/yaroslavvb/d67410e240369736fc4ba0267250ef27マージロジックの説明は次のとおりです。ここ - https://github.com/tensorflow/tensorflow/blob/01daba61e3a5099c6ad6439fa47e30c71560f06b/tensorflow/core/common_runtime/executor.cc#L2391 –