2017-05-04 12 views
1

私はモデルの2つのクラス、xyを定義しました。Tensorflow r-1.0を使用して、あるモデルの出力を別のモデルの入力とするにはどうすればよいですか?

class x(): 
    def __init__(self, x_inp1, x_inp2): 
     # do sth... 

    def step(self, session, encoder_inputs): 
     input_feed = {} 
     for l in range(encoder_size): 
      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 
     ... 
     output_feed = [x_output] 
     return session.run(x_output) 

class y(): 
    def __init__(self, y_inp1, y_inp2): 
     # do sth... 

    def step(self, encoder_inputs): 
     input_feed = {} 
     for l in range(encoder_size): 
      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 
     ... 

これらの機能は非常に似ています。そして、それらをグループ化する別のクラスを定義します。

class gp(): 
    def __init__(self, x_inp1, x_inp2, y_inp1, y_inp2): 
     with tf.variable_scope('x'): 
       self.x_model = x(x_inp1, x_inp2) 
     with tf.variable_scope('y'): 
       self.y_model = y(y_inp1, y_inp2) 
    def step(self, session, encoder_inputs): 
     x_output = self.x_model.step(session, encoder_inputs) 
     y_output = self.y_model.step(session, x_output) 
     ... 

y_modelは、入力としてx_modelの出力を取ることに注意してください。そして、私はmain機能でgp()を実行します。

with tf.Session() as sess: 
    gp_m = gp(x_inp1, x_inp2, y_inp1, y_inp2) 
    gp_m.step(sess, x_inp1, x_inp2, y_inp1, y_inp2) 

x_output = self.x_model.step(encoder_inputs)を実行した後とy_output = self.y_model.step(x_output)を行うために始める、私は、このようなエラーを得た:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'x/encoder0' with dtype int32 
[[Node: x/encoder0 = Placeholder[dtype=DT_INT32, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

でもx_modelにこのエラーポイントを注意してくださいステップ機能が終了しました。私はどのようにエラーなしでy_modelの入力としてx_modelの出力を使用することができますか?前もって感謝します!

答えて

0

session.runへの呼び出しをstepの機能外にすることを延期する必要があります。ここで問題となるのは、Yがグラフ内で接続されているため、XをトリガーするのがXを実行しようとしていることです。

代わりに、プログラムのグラフ作成とグラフの実行段階を完全に分ける方がよいので、いつどのプレースホルダーを指定するかを知ることができます。

+0

あなたの答えをありがとう!しかし、私はまだ少し混乱していることを申し訳なく思っています。 'session.run'への呼び出しをWHICH' step'関数の外側にすることを延期すべきですか? グラフを完全に分割する方法を教えてください。どうもありがとう!!! – user5779223

+0

グラフを作成する(つまり、ほとんどのtf関数を呼び出します)、実行するテンソルを返し、別の場所でsession.runを呼び出します。コードを構造化する良い方法を知るには、tf.estimator.Estimatorを見てみてください。 –

関連する問題