2017-10-22 22 views
1

私は、入力勾配を2倍にするカスタム勾配計算機能を持っています。TensorFlowカスタム勾配

import tensorflow as tf 

@tf.RegisterGradient("CustomSquare") 
def _custom_square_grad(op, grad): 
    return grad*2.0 

c = tf.constant(3.) 

s1 = tf.square(c) 
grad1 = tf.gradients(s1, c)[0] 

g = tf.get_default_graph() 
with g.gradient_override_map({"Square": "CustomSquare"}): 
    s2 = tf.square(c) 
    grad2 = tf.gradients(s2, c)[0] 

with tf.Session() as sess: 
    print(sess.run([c, s1, grad1])) 
    print(sess.run([c, s2, grad2])) 

私が得る結果は驚くべきものである:

[3.0, 9.0, 6.0] 
[3.0, 9.0, 2.0] 

私は2番目の結果は[3.0, 9.0, 12.0]であることを期待していました。私は何が欠けていますか?

ありがとうございました。要するに

答えて

1

、_custom_square_gradの正しいバージョンは次のようになります。

@tf.RegisterGradient("CustomSquare")            
def _custom_square_grad(op, grad):            
    x = op.inputs[0]                
    return 2.0 * (grad * 2.0 * x) 

コードを理解するために、あなたはどのようにgradient作品を知る必要があります。 tf.RegisterGradientを定義すると、グラデーションを出力から入力にBACK-PROPAGATEすることになっています。あなたがカスタマイズした勾配関数に勾配を倍増したいので、あなたは、単にgrad_x = 2.0 * (grad_y * 2.0 * x)に変更することができ

# Given y = tf.square(x) => y' = 2x 
grad_x = grad_y * 2.0 * x 

tf.squreの場合、デフォルトの勾配関数は、このようなものです。