1
の使用方法はtf.scan
です。私は最小限の作業例を書いています(下記のコードを参照)。しかし、関数Model._step()
が呼び出されるたびに、それは計算グラフの別のコピーを作成していませんか?そうでない場合は、どうしてですか?この関数は毎回新しいTensorFlowグラフを作成していますか?
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # to avoid TF suggesting SSE4.2, AVX etc...
class Model():
def __init__(self):
self._inputs = tf.placeholder(shape=[None], dtype=tf.float32)
self._predictions = self._compute_predictions()
def _step(self, old_state, new_input):
# ---- In here I will write a much more complex graph ----
return old_state + new_input
def _compute_predictions(self):
return tf.scan(self._step, self._inputs, initializer = tf.Variable(0.0))
@property
def predictions(self):
return self._predictions
@property
def inputs(self):
return self._inputs
def test(sess, model):
sess.run(tf.global_variables_initializer())
print(sess.run(model.predictions, {model.inputs: [1.0, 2.0, 3.0, 4.0]}))
test(tf.Session(), Model())
もちろんこれは最小限の例ですから、私はもっと複雑なグラフが必要です。
ありがとうございました。私は 'tf.scan()'が 'Model.inputs'の各要素に対して' Model._step() 'を呼び出すと思っていました。 – Ziofil