2016-01-04 3 views
12

は私がTrue秒数を取得するには、結果のshape[0]を使用することができるようにtf.whereは、True値の位置を返すことを理解しています。カウント数

しかし、これを試して使用すると、次元は不明です(これは実行時に計算する必要があるため意味があります)。ですから、私の質問は、ディメンションにアクセスして合計のような操作でそれを使用するにはどうすればいいのですか?

例えば:

myOtherTensor = tf.constant([[True, True], [False, True]]) 
myTensor = tf.where(myOtherTensor) 
myTensor.get_shape() #=> [None, 2] 
sum = 0 
sum += myTensor.get_shape().as_list()[0] # Well defined at runtime but considered None until then. 
+0

'myTensor'はどのようなものが見えますか? – erip

+0

@eripより明示的な例で更新しました。 –

+0

期待される出力は2ですか? – erip

答えて

24

あなたがそれらに合計を浮遊し計算するために値をキャストすることができます。 tf.reduce_sum(tf.cast(myOtherTensor, tf.float32))

場合は実際のユースケースに応じて、あなたはまた、行/列あたりの合計を計算することができますコールの縮尺を指定します。

+0

ありがとうございます! Pythonのブール同値性を忘れてしまった。 –

+0

キャストなしでこれを行う方法はありますか(および対応するメモリヒット)? –

4

ラファウの答えはほぼ確実にあなたのテンソルでtrue要素の数をカウントするための最も簡単な方法が、尋ねあなたの質問の他の部分である:[H]

OW私は次元をアクセスし、それを使用することができます合計のような操作で?

これを行うには、テンソルの実行時の値に作用するTensorFlowのshape-related operationsを使用します。例えば、tf.size(t)は、tに要素数を含むスカラーTensorを生成し、tf.shape(t)は、各次元にtのサイズを含む1D Tensorを生成します。非ゼロ値tf.count_nonzeroをカウントするtensorflow機能があり

myOtherTensor = tf.constant([[True, True], [False, True]]) 
myTensor = tf.where(myOtherTensor) 
countTrue = tf.shape(myTensor)[0] # Size of `myTensor` in the 0th dimension. 

sess = tf.Session() 
sum = sess.run(countTrue) 
+0

これは本当にありがとう!結果の谷に0をループすることは可能でしょうか?それは単一元素のテンソルであろうか?だから私はそれをxrange()に渡すことができませんでした(xrangeがTensorsを受け入れるように拡張されていない限り)?これが不明な場合、私は詳細についてそれについて新しい質問をします。 –

+0

ええと、私はあなたが何を求めているのかよく分かりません。あなたは 'myTensor'の寸法をループし、各寸法のサイズを使用したいですか?おそらく、これは新しい質問で最もよく対処されるだろう。 – mrry

+0

ええ、それはあまり言葉になっていませんでした。私は1つを書くでしょう。 –

0

:これらの演算子を使用して

は、あなたのプログラムは、のように書くこともできます。この関数は、axiskeep_dims引数も受け入れます。ここで

は簡単な例です:

import numpy as np 
import tensorflow as tf 
a = tf.constant(np.random.random(100)) 
with tf.Session() as sess: 
    print(sess.run(tf.count_nonzero(tf.greater(a, 0.5)))) 
2

私は、これはそれを行うための最も簡単な方法だと思います:

In [38]: myOtherTensor = tf.constant([[True, True], [False, True]]) 

In [39]: if_true = tf.count_nonzero(myOtherTensor) 

In [40]: sess.run(if_true) 
Out[40]: 3