2016-05-06 13 views
11

グラフに条件付き制御フローが必要です。 predTrueの場合、グラフは変数を更新するopを呼び出す必要があります。それ以外の場合は変数を変更しません。簡略化されたバージョンである:ただし、Iはupdate_x_2tf.condによって選択されていない場合にアサインOPも呼ばれている意味と同じ結果y=[2]にその両方pred=Truepred=Falseリードを見つける`tf.cond`の動作によって混乱します

pred = tf.constant(True) 
x = tf.Variable([1]) 
assign_x_2 = tf.assign(x, [2]) 
def update_x_2(): 
    with tf.control_dependencies([assign_x_2]): 
    return tf.identity(x) 
y = tf.cond(pred, update_x_2, lambda: tf.identity(x)) 
with tf.Session() as session: 
    session.run(tf.initialize_all_variables()) 
    print(y.eval()) 

。これを説明する方法は?そしてこの問題を解決するには?

答えて

17

TL; DR:あなたが分岐の1つに(譲渡など)の副作用を実行するtf.cond()したい場合は、あなたがtf.cond()に渡す関数内で副作用を行い、オペアンプを作成する必要があります。

tf.cond()の動作は少し直感的ではありません。 TensorFlowグラフの実行はグラフ内を順方向に進むので、ブランチで参照するすべての操作は、条件が評価される前に実行する必要があります。これは、真のブランチと偽のブランチの両方がtf.assign()演算子に制御依存関係を持つことを意味します。は、 False`でも、常に2に設定されます。

解決策は、真のブランチを定義する関数の中にtf.assign()を作成することです。たとえば、次のようにコードを構造化することができます。

pred = tf.placeholder(tf.bool, shape=[]) 
x = tf.Variable([1]) 
def update_x_2(): 
    with tf.control_dependencies([tf.assign(x, [2])]): 
    return tf.identity(x) 
y = tf.cond(pred, update_x_2, lambda: tf.identity(x)) 
with tf.Session() as session: 
    session.run(tf.initialize_all_variables()) 
    print(y.eval(feed_dict={pred: False})) # ==> [1] 
    print(y.eval(feed_dict={pred: True})) # ==> [2] 
+0

ええ、それは私を混乱させるものです。 'tf.cond'を実行する前にランタイムはすべての依存関係が実行されるようにします。 'True'ブランチと' False'ブランチでのopsの依存関係は 'cond'の依存関係でもあるので、ブランチのオペレーションは決して実行されないかもしれませんが、その依存関係はすべて実行されます。 –

+1

Yep - グラフのプルーニングは、実行するブランチの可能性のあるすべての依存関係を考慮し、ブランチの1つの内部で定義されていれば実行を禁止します。これは、 'CondContext'がピボットにコントロールの依存関係を追加するためです(https: //github.com/tensorflow/tensorflow/blob/2b2f312cb07765c628d264abe326bfc286f462c1/tensorflow/python/ops/control_flow_ops.py#L1092)、分岐が取られていない場合、その依存関係はデッドテンソル(実行を妨げる)になります。 – mrry

+0

このような理由は何でしたか?非アクティブブランチの背後にあるサブグラフを整理してみてはいかがですか? –

関連する問題