2017-03-17 6 views
0

与えられた演算(通常は損失)が依存するすべての変数を見つける方法はありますか? これを使用して、さまざまなset().intersection()の組み合わせを使用して、このコレクションをoptimizer.minimize()またはtf.gradients()に渡します。テンソルフロー演算が依存するすべての変数を見つける

これまでのところ、私はop.op.inputsを発見し、その上の簡単なBFSを試してみましたが、tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)またはslim.get_variables()

によって返された私は、「Tensor.op._idに対応する対応関係があるように思えるんVariableオブジェクト時にチャンスはありませんand Variables.op._id`フィールドがありますが、それは私が頼りにするべきものなのかどうかはわかりません。

これをやりたいとは思わないでしょうか? 私ははもちろん私のグラフを構築する間に変数の私の別々のセットを細かく構築することができますが、私はモデルを変更すると何かを見逃しやすいでしょう。

documentation for tf.Variable.op

答えて

1

は特に明確ではないが、それはthe implementation of a tf.Variableで使用される重要なtf.Operationを参照んtf.Variableに依存する任意のOPは、その操作からのパスになります。以来tf.Operationオブジェクトはハッシュ可能である、あなたが対応するtf.Variableオブジェクトにtf.Operationオブジェクトをマップdictのキーとして使用して、以前のようにBFSを実行することができます。

op_to_var = {var.op: var for var in tf.trainable_variables()} 

starting_op = ... 
dependent_vars = [] 

queue = collections.deque() 
queue.append(starting_op) 

visited = set([starting_op]) 

while queue: 
    op = queue.popleft() 
    try: 
    dependent_vars.append(op_to_var[op]) 
    except KeyError: 
    # `op` is not a variable, so search its inputs (if any). 
    for op_input in op.inputs: 
     if op_input.op not in visited: 
     queue.append(op_input.op) 
     visited.add(op_input.op) 
+0

これは無限ループに問題がある可能性があります...私はこれを試み、それは掛かった。どの 'Op'sがすでに' queue'に入っていたかを追跡する 'set'を追加しました。すぐに戻りました。 – eqzx

+0

あなたは大丈夫です!グラフにサイクルが含まれていれば、元のコードは失敗します。私はそれを更新して 'visited'セットを使用しました。 – mrry

関連する問題