同じ論理コードを使用してテンソルフローとnumpyを比較しています。tf.where()はnp.where()と同じ結果を出力しません。
tf.whereを実装すると、np.whereと同じ結果が得られません
以下のコードまたは使用方法の問題は何ですか?
データ:
X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0)
tensorflow tf.whereのおもちゃコード:
X = tf.placeholder(dtype=tf.int32, shape=[10, 10])
with tf.Session() as sess:
print(sess.run(tf.where(X > 5, tf.zeros([10, 10], dtype=tf.int32),
X), feed_dict={X: X_batch}))
numpyのnp.whereのおもちゃコード:
np.where(X_batch > 5, np.zeros([10,10]), X_batch)
コードいくつかのタイプミスがありました。修正しました
X_batchマトリックスを投稿することもできます –
@NipunWijerathneああ、更新しました。ごめんなさい。 –