2016-06-12 5 views
5

グラフがテンソルフローでどのように機能するのか、それらにアクセスするのは難しいと思います。私の直感は、 'with graph:'の下にある線が単一の実体としてグラフを形成するということです。したがって、私は、インスタンス化するとグラフを作成し、グラフを実行する関数を持つクラスを作成することにしました。Tensorflow:クラス内のグラフを作成して実行する

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      prediction = ... 
      cost  = ... 
      optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(optimizer, feed_dict) 
      loss = sess.run(cost, feed_dict) 
      ... 
     return variables 

次のステップでは、グラフを構築し、それを実行するために、クラスに渡すパラメータを組み立てますメインファイルを作成することです。

#Main file 
... 
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... } 

#Building graph 
G = Graph(parameters_dict) 
P = G.launchG(Input) 
... 

これは私にとって非常にエレガントですが、それは非常に(明らかに)動作しません。実際、launchG関数は、グラフに定義されたノードにアクセスできないようで、次のようなエラーが出るようです。

---> 26 sess.run(optimizer, feed_dict) 

NameError: name 'optimizer' is not defined 

おそらくそれは私のpython(とtensorflow)であまりにも限られているの理解が、私はグラフ(G)を作成し、引数としてこのグラフとのセッションを実行すると、アクセス権を与えるべきだと不思議な印象の下にありましたそのノードの中に明示的にアクセスする必要はありません。

啓発?

答えて

7

ノードpredictioncost、及びoptimizerメソッド__init__で作成したローカル変数であり、それらは、法launchGでアクセスすることができません。

最も簡単な修正があなたのクラスGraphの属性として宣言することです:

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      self.prediction = ... 
      self.cost  = ... 
      self.optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(self.optimizer, feed_dict) 
      loss = sess.run(self.cost, feed_dict) 
      ... 
     return variables 

あなたはまたgraph.get_tensor_by_namegraph.get_operation_by_nameとその正確な名前を使用して、グラフのノードを取得することができます。

関連する問題