2017-10-13 4 views
1

の使用方法はtf.scanです。私は最小限の作業例を書いています(下記のコードを参照)。しかし、関数Model._step()が呼び出されるたびに、それは計算グラフの別のコピーを作成していませんか?そうでない場合は、どうしてですか?この関数は毎回新しいTensorFlowグラフを作成していますか?

import tensorflow as tf 
import os 
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # to avoid TF suggesting SSE4.2, AVX etc... 

class Model(): 
    def __init__(self): 
     self._inputs = tf.placeholder(shape=[None], dtype=tf.float32) 
     self._predictions = self._compute_predictions() 

    def _step(self, old_state, new_input): 
     # ---- In here I will write a much more complex graph ---- 
     return old_state + new_input 

    def _compute_predictions(self): 
     return tf.scan(self._step, self._inputs, initializer = tf.Variable(0.0)) 

    @property 
    def predictions(self): 
     return self._predictions 

    @property 
    def inputs(self): 
     return self._inputs 

def test(sess, model): 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]})) 

test(tf.Session(), Model()) 

もちろんこれは最小限の例ですから、私はもっと複雑なグラフが必要です。

答えて

1

Model._step()メソッドは、Modelオブジェクトごとに1回だけ呼び出されます。 tf.while_loop()関数のように、tf.scan()関数は、指定された関数を1回呼び出すだけで、ループを含むグラフを作成し、同じグラフをループの繰り返しごとに使用します。

(あなたは多くのModelオブジェクトを構築する場合は、Modelオブジェクトを持っているとして、あなたは、グラフのコピーの数が同じになってしまいますので注意してください。)

+0

ありがとうございました。私は 'tf.scan()'が 'Model.inputs'の各要素に対して' Model._step() 'を呼び出すと思っていました。 – Ziofil

関連する問題