2017-10-10 18 views
1

Tensorflowでは、特定の出力テンソルを評価するために必要なプレースホルダテンソルをすべて見つける方法はありますか?つまり、sess.run(output_tensor)を呼び出すとfeed_dictに入力する必要があるすべての(プレースホルダの)テンソルを返す関数がありますか?ここでTensorflowグラフに必要なプレースホルダを見つける

私は擬似コードで、やってみたいものの例です:

import tensorflow as tf 

a = tf.placeholder(dtype=tf.float32,shape=()) 
b = tf.placeholder(dtype=tf.float32,shape=()) 
c = tf.placeholder(dtype=tf.float32,shape=()) 
d = a + b 
f = b + c 

# This should return [a,b] or [a.name,b.name] 
d_input_tensors = get_dependencies(d) 

# This should return [b,c] or [b.name,c.name] 
f_input_tensors = get_dependencies(f) 

EDIT:は明確にするために、私は(必ずしも)ないです、グラフ内のプレースホルダのすべてを探して、ちょうど特定の出力テンソルを定義するのに必要なもの。目的のプレースホルダは、グラフ内のすべてのプレースホルダのサブセットにすぎません。

+0

グラフ内のすべてのプレースホルダを取得するには、https://stackoverflow.com/a/44371483/4834515という回答があります。依存関係を得るために...考えていない。 – Seven

+0

@ Seven私はすべてのプレースホルダではなく、依存関係を取得したいと思っています。私は明確にするために私の質問を編集します。 –

答えて

0

いくつかの工夫の後とthisほぼ同じSO質問を発見し、私は次の解決策を考え出した:

def get_tensor_dependencies(tensor): 

    # If a tensor is passed in, get its op 
    try: 
     tensor_op = tensor.op 
    except: 
     tensor_op = tensor 

    # Recursively analyze inputs 
    dependencies = [] 
    for inp in tensor_op.inputs: 
     new_d = get_tensor_dependencies(inp) 
     non_repeated = [d for d in new_d if d not in dependencies] 
     dependencies = [*dependencies, *non_repeated] 

    # If we've reached the "end", return the op's name 
    if len(tensor_op.inputs) == 0: 
     dependencies = [tensor_op.name] 

    # Return a list of tensor op names 
    return dependencies 

:これはまた、変数と定数をプレースホルダを返しますが、しません。 dependencies = [tensor_op.name]dependencies = [tensor_op.name] if tensor_op.type == 'Placeholder' else []に置き換えられた場合は、プレースホルダのみが返されます。

関連する問題