2016-09-27 3 views
0

tf.while_loop()でテンソルの中間値を取得する必要がありますが、最後に返された値のみが返されます。TensorFlow:tf.while_loop()の変数の中間値を取得するには?

たとえば、私は3ページを持ち、その次元は3 * 2 * 4の変数xを持っています。今度は、各ページを1回フェッチし、各ページの合計、ページ合計、平均値、最大値、最小値を計算します。次に、条件とボディ関数を定義し、必要な結果を計算するためにtf.while_loop()を使いたいとします。ソースコードは以下の通りです。

import tensorflow as tf  
x = tf.constant([[[41, 8, 48, 82], 
         [9, 56, 67, 23]], 
        [[95, 89, 44, 54], 
         [11, 33, 29, 1]], 
        [[34, 9, 5, 70], 
         [14, 35, 18, 17]]], dtype=tf.int32) 

def cond(out, count, x): 
    return count < 3 

def body(out, count, x): 
    outTemp = tf.slice(x, [count, 0, 0], [1, -1, -1]) 
    count += 1 
    outPack = tf.unpack(out) 
    outPack[0] += tf.reduce_sum(outTemp) 
    outPack[1] = tf.reduce_sum(outTemp) 
    outPack[2] = tf.reduce_mean(outTemp) 
    outPack[3] = tf.reduce_max(outTemp) 
    outPack[4] = tf.reduce_min(outTemp) 
    out = tf.pack(outPack) 
    return out, count, x 

out = tf.Variable(tf.constant([0, 0, 0, 0, 0])) # total sum, page sum, mean, max, min 
count = tf.Variable(tf.constant(0)) 
result = tf.while_loop(cond, body, [out, count, x]) 

init = tf.initialize_all_variables() 
with tf.Session() as sess: 
    sess.run(init) 
    print(sess.run(x)) 
    print(sess.run(result)[0]) 

プログラムを実行すると、前回の戻り値のみが返され、最後のページの結果しか取得できません。 質問は、どのように各ページの結果を得ることができますか?tf.while_loop()から中間値を取得する方法はありますか?

ありがとうございます。

答えて

0

任意の変数の「中間値」を得るには、前述の変数を評価する際に関連するメッセージを出力するという副作用を伴う恒等操作であるtf.Print演算を単純に使用できます。

例として、

x = tf.Print(x, [x], "Value of x is: ") 

はあなたが値を報告させたい任意の行に配置することができます。

+0

ありがとうございます。私は知りたいです、あなたはtf.Print()をどこに置くでしょうか?実際、関数body()に入れると、結果は出力されません。 tf.while_loop()関数は特別な方法で実行されたので、 – Kongsea

+0

それは変です。ここにあなたの体の機能を貼り付けることができますか? – cberkay

+0

こんにちは、私は質問をしたときに私のすべてのコードが貼り付けられており、あなたは質問のセクションで見ることができます。ありがとうございました。 – Kongsea

関連する問題