2017-06-07 12 views
1

Tensorflowを学習していて、コードを正しく構造化しようとしています。私は多かれ少なかれ、グラフを裸で、またはクラスメソッドとして構築する方法を知っていますが、コードをどのように構造化するのが最適かを考えようとしています。Tensorflowグラフを関数内に作成する

def build_graph():     
    g = tf.Graph()  
    with g.as_default():      
     a = tf.placeholder(tf.int8) 
     b = tf.add(a, tf.constant(1, dtype=tf.int8)) 
    return g 

graph = build_graph() 
with tf.Session(graph=graph) as sess: 
    feed = {a: 3}  
    print(sess.run(b, feed_dict=feed)) 

ちょうど私がそれを行うとき、私はエラーを取得、しかし4をプリントアウトする必要があります:私はこれがあるためであるかなり確信している

Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=int8) is not an element of this graph. 

を私は簡単な例を試してみましたbuild_graphの内部のプレースホルダはプライベートですが、with tf.Session(graph=graph)はそれを処理してはいけませんか?このような状況でフィードディクテーションを使用するより良い方法はありますか?

答えて

0

私はより良い方法を探していますので、私の答えはおそらく最高ではありません。あなたは、このような

a = tf.placeholder(tf.int8, name='a') 
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b') 

として、名前をaを与えるとb場合にもかかわらず、そして、あなたは

graph = build_graph() 

a = graph.get_tensor_by_name('a:0') 
b = graph.get_tensor_by_name('b:0') 

with tf.Session(graph=graph) as sess: 
    feed = {a: 3}  
    print(sess.run(b, feed_dict=feed)) 

P.S.を行うことができますabという名前は必要ありません。後で参照するほうが簡単です。また、より良い解決策を見つけた場合は、それも共有してください。

3

いくつかのオプションがあります。

オプション1:テンソル自体の代わりにテンソルの名前を渡すだけです。この場合

with tf.Session(graph=graph) as sess: 
    feed = {"Placeholder:0": 3}  
    print(sess.run("Add:0", feed_dict=feed)) 

、それは代わりに、上記のように、デフォルトの名前を使用するのでは、ノードに意味のある名前を与えるために、おそらく最善です:

def build_graph(): 
    g = tf.Graph() 
    with g.as_default(): 
     a = tf.placeholder(tf.int8, name="a") 
     b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b") 
    return g 

graph = build_graph() 
with tf.Session(graph=graph) as sess: 
    feed = {"a:0": 3} 
    print(sess.run("b:0", feed_dict=feed)) 

リコール"foo"という名前の操作の出力は"foo:0"という名前のテンソルであることを、"foo:1"など。ほとんどの操作には1つの出力しかありません。

オプション2build_graph()関数はすべての重要なノードを返します。

def build_graph(): 
    g = tf.Graph() 
    with g.as_default(): 
     a = tf.placeholder(tf.int8) 
     b = tf.add(a, tf.constant(1, dtype=tf.int8)) 
    return g, a, b 

graph, a, b = build_graph() 
with tf.Session(graph=graph) as sess: 
    feed = {a: 3} 
    print(sess.run(b, feed_dict=feed)) 

オプション3:コレクションに重要なノードを追加

def build_graph(): 
    g = tf.Graph() 
    with g.as_default(): 
     a = tf.placeholder(tf.int8) 
     b = tf.add(a, tf.constant(1, dtype=tf.int8)) 
    for node in (a, b): 
     g.add_to_collection("important_stuff", node) 
    return g 

graph = build_graph() 
a, b = graph.get_collection("important_stuff") 
with tf.Session(graph=graph) as sess: 
    feed = {a: 3} 
    print(sess.run(b, feed_dict=feed)) 

はオプション4:あなたはget_tensor_by_name()

def build_graph(): 
    g = tf.Graph() 
    with g.as_default(): 
     a = tf.placeholder(tf.int8, name="a") 
     b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b") 
    return g 

graph = build_graph() 
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")] 
with tf.Session(graph=graph) as sess: 
    feed = {a: 3} 
    print(sess.run(b, feed_dict=feed)) 

を使用することができます@poheにより示唆されるように、私は個人的にオプションを使用します2最も頻繁に、それはかなり簡単で、名前で遊ぶことを必要としません。グラフが大きい場合はオプション3を使用し、コレクションがモデルと共に保存されるため、長い間生きていきます。実際に重要なことを文書化するのはすばやい方法です。私は実際にオブジェクトへの参照を持つことを好むため、実際にはオプション1を使用しません(理由はわかりません)。オプション4は、他の人が作成したグラフを使用して作業しており、テンソルを直接参照していない場合に便利です。

希望すると便利です。

関連する問題