tf.while_loop()でテンソルの中間値を取得する必要がありますが、最後に返された値のみが返されます。TensorFlow:tf.while_loop()の変数の中間値を取得するには?
たとえば、私は3ページを持ち、その次元は3 * 2 * 4の変数xを持っています。今度は、各ページを1回フェッチし、各ページの合計、ページ合計、平均値、最大値、最小値を計算します。次に、条件とボディ関数を定義し、必要な結果を計算するためにtf.while_loop()を使いたいとします。ソースコードは以下の通りです。
import tensorflow as tf
x = tf.constant([[[41, 8, 48, 82],
[9, 56, 67, 23]],
[[95, 89, 44, 54],
[11, 33, 29, 1]],
[[34, 9, 5, 70],
[14, 35, 18, 17]]], dtype=tf.int32)
def cond(out, count, x):
return count < 3
def body(out, count, x):
outTemp = tf.slice(x, [count, 0, 0], [1, -1, -1])
count += 1
outPack = tf.unpack(out)
outPack[0] += tf.reduce_sum(outTemp)
outPack[1] = tf.reduce_sum(outTemp)
outPack[2] = tf.reduce_mean(outTemp)
outPack[3] = tf.reduce_max(outTemp)
outPack[4] = tf.reduce_min(outTemp)
out = tf.pack(outPack)
return out, count, x
out = tf.Variable(tf.constant([0, 0, 0, 0, 0])) # total sum, page sum, mean, max, min
count = tf.Variable(tf.constant(0))
result = tf.while_loop(cond, body, [out, count, x])
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
print(sess.run(x))
print(sess.run(result)[0])
プログラムを実行すると、前回の戻り値のみが返され、最後のページの結果しか取得できません。 質問は、どのように各ページの結果を得ることができますか?tf.while_loop()から中間値を取得する方法はありますか?
ありがとうございます。
ありがとうございます。私は知りたいです、あなたはtf.Print()をどこに置くでしょうか?実際、関数body()に入れると、結果は出力されません。 tf.while_loop()関数は特別な方法で実行されたので、 – Kongsea
それは変です。ここにあなたの体の機能を貼り付けることができますか? – cberkay
こんにちは、私は質問をしたときに私のすべてのコードが貼り付けられており、あなたは質問のセクションで見ることができます。ありがとうございました。 – Kongsea