2017-07-11 20 views
0

tf.where()が出力する配列シェイプを回避する方法があるのだろうかと思います。たとえば、ここに私が実行しようとしていますコードです。しかし、事実によるエラーの方法は、TFことはありtf.where()シェイプの出力への回避策

[[[ 1,100] 
    [100, 2] 

[[100, 3] 
    [ 4,100]]] 

:私は、出力は次のようになり期待何

import tensorflow as tf 

with tf.Session() as sess: 
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]).eval() 
    c = tf.where(tf.equal(a,0)).eval() 
    c = tf.multiply(100,c).eval() 
    c = tf.add(a,c) 

    print(c.eval()) 

です。ここで()は私のコードを2x2x2ではなく4x3テンソルとして出力します。効果的にすべての0を100に置き換えるために使用できる別のコマンドセットがありますか?このメソッドは、2-D配列に対して機能します。

答えて

0

テンソル100を作成し、それをwhere呼び出しで使用できます。 documentationによると、この関数は値を読み取るために2つのオプションのテンソルを取ります。

with tf.Session() as sess: 
    a = tf.constant([[[1,0],[0,2]],[[0,3],[4,0]]]) 
    h = tf.multiply(tf.ones(a.shape, tf.int32), 100) 
    c = tf.where(tf.equal(a, 0), h, a) 

    print(c.eval())