2017-10-25 13 views
-1

同じ論理コードを使用してテンソルフローとnumpyを比較しています。tf.where()はnp.where()と同じ結果を出力しません。

tf.whereを実装すると、np.whereと同じ結果が得られません
以下のコードまたは使用方法の問題は何ですか?

データ:

X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0) 

tensorflow tf.whereのおもちゃコード:

X = tf.placeholder(dtype=tf.int32, shape=[10, 10]) 

with tf.Session() as sess: 
    print(sess.run(tf.where(X > 5, tf.zeros([10, 10], dtype=tf.int32), 
          X), feed_dict={X: X_batch})) 

numpyのnp.whereのおもちゃコード:

np.where(X_batch > 5, np.zeros([10,10]), X_batch) 

コードいくつかのタイプミスがありました。修正しました

+0

X_batchマトリックスを投稿することもできます –

+0

@NipunWijerathneああ、更新しました。ごめんなさい。 –

答えて

0

私はコードを編集しました。 tf.where()への入力は、np.where()と同じにする必要があります。したがって、tf.where()への引数は、np.where()メソッドの引数として、10 * 10の行列とx_batchをゼロとして与える必要があります。

import tensorflow as tf 
import numpy as np 


X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0) 

#print(X_batch) 

X = tf.placeholder(dtype=tf.int32, shape=[10, 10]) 

with tf.Session() as sess: 
    print(sess.run(tf.where(X > 5, tf.fill([10, 10], 0), 
          X), feed_dict={X: X_batch})) 

これが役に立ちます。

+1

どこに問題があったのか、それに対する解決策は何かを教えてください。 – Evert

+0

ありがとうございました。それは単にタイプミスだった。 –

+0

私は私の答えを編集 –

関連する問題