2017-06-06 7 views
-1

私は3つの配列X,YおよびZを持っています。 Zの対応する要素が真である場合、resXの要素を入れたいと思います。それ以外の場合は、Yの要素を入れます。ここで()は1〜2の位置引数をとりますが、3は与えられました

私はこのようにそれを実装:

X = tf.constant([[1, 2], [3, 4]]) 
Y = tf.constant([[5, 6], [7, 8]]) 
Z = tf.constant([[True, False], [False, True]], tf.bool) 
res = tf.where(Z, X, Y) 
print(res.eval()) 

しかし、私はこのエラーを取得しています:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given 

私はhereからtf.whereのdefinitonを見て、私の使い方は罰金です。

何が問題なのでしょうか?

+0

あなたは 'tf.where(Z、X = X、Y = Y)' – pramod

+0

あなたのコードはTensorFlow 1.0.1で正常に動作しますので、私は好奇心を試すことができます。これはTFのバージョンを使用していますか? – npf

答えて

1

TensorFlowの古いバージョンを使用していると思われます。

r0.10でtf.whereは2つの引数しか取らなかった。

tf.where(input, name=None)

https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops/sequence_comparison_and_indexing#where

+0

私は 'pip'でインストールしたので、おそらく' 0.8.0'を使用しています。 – octavian

+0

これは意味があります。おそらく最新のバージョンをインストールするべきです:https://www.tensorflow.org/install/ – npf

関連する問題