2017-02-24 1 views
0

に基づいて1次元テンソルを変更します。テンソルがある場合たとえば、:私はtensorflowにおける1Dテンソルの各要素に条件を適用し、したがって入力テンソルを変更する条件

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0])

私は、各要素が0.1よりも大きいか否かを確認したいです。もしそうであれば、その要素は他の10になります。どのように同じtensorflowに行うことができますか?私はPythonの関数を記述し、その後py_funcを使用してtensorflowでそれを使用して、今まで試してみましたが、それが動作していない何

。コードの下を参照してください -

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0]) 
with tf.Session() as sess: 
    y = tf.py_func(round_with_threshold, [y_true], tf.float32) 
    y.eval() 

def round_with_threshold(arr): 
threshold = 0.1 
rounded_arr = np.zeros(arr.shape[0]) 
for i in range(arr.shape[0]): 
    if arr[i]>=threshold: 
     rounded_arr[i] = 1 
    else: 
     rounded_arr[i] = 0 
return rounded_arr 

は任意のPython関数を記述することなくtensorflowでこれを行うには、それは可能ですか?

答えて

0

あなたが行うことができ、以下:

import tensorflow as tf 

y_true = tf.Variable([0.0, 0.3, 0.0, 0.4, 0.0, 0.1, 0.0, 0.0]) 
comp_op = tf.greater(y_true, 0.1) # returns boolean tensor 
cast_op = tf.cast(comp_op, tf.int32) # casts boolean tensor into int32 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    print(sess.run(cast_op)) 

プリント:[0 1 0 1 0 0 0 0]

関連する問題