2016-08-01 3 views
5

私は、次のしているシンプルなプレースホルダ:Tensorflowのパルサをtf.cond内の関数に渡すにはどうすればいいですか?

x = tf.placeholder(tf.float32, shape=[1]) 
y = tf.placeholder(tf.float32, shape=[1]) 
z = tf.placeholder(tf.float32, shape=[1]) 

として定義された2つの機能f1f2あります

pred = tf.placeholder(tf.bool, shape=[1]) 
result = tf.cond(pred, f1(x,y), f2(y,z)) 

def fn1(a, b): 
    return tf.mul(a, b) 
def fn2(a, b): 
    return tf.add(a, b) 

は、今私はpredを条件に基づいて結果を計算するには、しかしそれは私にエラーfn1 and fn2 must be callableと言う。

fn1fn2を実行時に受信できるようにするにはどうすればよいですか?

sess.run(result, feed_dict={x:1,y:2,z:3,pred:True}) 

答えて

1

最も簡単な通話中、あなたの関数を定義するには、次のようになります:私は次のように呼び出したい

result = tf.cond(pred, lambda: tf.mul(a, b), lambda: tf.add(a, b)) 
9

あなたはラムダを使用して機能し、コードにパラメータを渡すことができますがようですベローズ。

x = tf.placeholder(tf.float32) 
y = tf.placeholder(tf.float32) 
z = tf.placeholder(tf.float32) 

def fn1(a, b): 
    return tf.mul(a, b) 

def fn2(a, b): 
    return tf.add(a, b) 

pred = tf.placeholder(tf.bool) 
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z)) 

その後、あなたは次のようにそれを呼び出すことができます。

with tf.Session() as sess: 
    print sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}) 
    # The result is 2.0 
関連する問題