2017-10-22 14 views
0

私は簡単な質問があります。例として、ブール関数を実行する単層パーセプトロンをモデル化するテンソルフロープログラムを作成しているとします。そのパーセプトロンを含むクラスを作成している場合は、クラスのインスタンス化ごとにtf.graphを使用して別のグラフを作成する必要がありますか?テンソルフローとクラスのintantiations。インスタンス化ごとに個別のグラフが必要ですか?

I.e.私はブール関数AND、ORおよびNANDをモデル化し、私はクラスパーセプトロン書く場合: ...

は私がするように、各クラスのinstantionためのinit方法でtf.graph()を実行する必要があります彼らは同じ名前空間を使用していませんか?または、それぞれの操作の種類ごとにスコープを定義して、3つの操作をすべて実行するために1つのクラスのみをインスタンス化することはできますか?

おかげ

答えて

1

これは、原理的には両方のソリューションが正しい、あなたはあなたの問題を与えたい全体的なアーキテクチャに依存します。

相互作用するモデルが必要であることが予想される場合(たとえば、両方のモデルを必要とするコスト関数/推論を作成する場合)、それらをすべて同じグラフに入れることをお勧めします。そうでなければ、抽出する必要がありますそれらを相互作用させるためにセッションを通して

モデルがお互いに話すことがないことが分かっている場合は、別々のグラフでモデルを構築し、不安定なやりとりを避けるようにしてください(たとえば、モデルのコスト関数を最小化するこの第二の溶液のために別のモデルに属している変数)

、私は(私は、これは良い習慣であるかどうかわからないです)、次のようにクラスを構築したい:

class MyModel(object): 
    def __init__(self, **kwargs): 
     self.g = tf.Graph() 
     self.sess = None 

     with self.g.as_default(): 
      self.build_graph() 

    def build_graph(self): 
     # here all the logic to build the graph. 
     # usually you store the outputs of the model in attributes 
     # self.input_data_placeholder 
     # self.target_data_placeholder 
     # self.prediction 
     # self.cost_tensor 
     # self.learning_rate_placeholder 
     # self.train_op 
     # self.accuracy_tensor 
     # self.init_op 

    def open_session(self): 
     if self.sess is None: 
      sess = tf.Session(graph=self.g) 
      sess.run(init_op) 
     else: 
      sess = self.sess 
     return sess 

    def predict(self, data): 
     sess = self.open_session() 
     pred = sess.run(
      self.prediction, 
      feed_dict={self.input_data_placeholder: data} 
     ) 
     return pred 

    def update(self, data, target, learning_rate): 
     sess = self.open_session() 
     cost, _ = sess.run(
      [self.cost_tensor, self.train_op], 
      feed_dict={self.input_data_placeholder: data, 
         self.target_data_placeholder: target, 
         self.learning_rate_placeholder: learning_rate 
         } 
     return cost 

今クラスが隠れていますすべてのテンソルフロー関連のもの、そしてnumpy.arraysでそれと通信する

関連する問題