2016-07-13 39 views
7

この Paperに導入された「ガイド付きバックプロパゲーション」のテクニックをTensorFlowに実装したいと考えています。これは recipeに記載されています。TensorFlowのガイド付きバックプロパゲーション

これは、たとえば入力wrtのグラジエントを計算するときに計算されることを意味します。 NNの出力は、すべてのRELUユニットで計算されたグラジエントをに変更する必要があります。具体的には、これらのユニット上の逆伝播された信号は、この技法を働かせるために、ゼロで閾値処理されなければならない。換言すれば、負であるRELUの偏導関数は無視されなければならない。

これらの勾配計算をのテスト例にのみ適用したいと考えている、つまりモデルのパラメータを更新したくないのですが、どうすればよいですか?

私がこれまでに(失敗した)二つのことを試してみました:

  1. を使用 tf.py_funcそれはグラム経由勾配操作だ再定義する資格があるRELU、私の簡単なnumpyのバージョンをラップします.gradient_override_mapコンテキストマネージャ。

  2. BackPropの前後の値を集め、Relusからのものにしきい値を適用します。

現在のところ私が持っていないTFの内部についての知識が必要なので、両方の方法で失敗しました。

誰か他のルートを示唆したり、コードをスケッチしたりできますか?

ありがとうございます。

答えて

5

tf.gradientsには、この目的で使用できるパラメータgrad_ysがあります。次のようにネットワークが一つだけrelu層を有しているとします

before_relu = f1(inputs, params) 
after_relu = tf.nn.relu(before_relu) 
loss = f2(after_relu, params, targets) 

まず、after_reluまでの導関数を計算します。

Dafter_relu = tf.gradients(loss, after_relu)[0] 

次に、送信するグラデーションをしきい値に設定します。

Dafter_relu_thresholded = tf.select(Dafter_relu < 0.0, 0.0, Dafter_relu) 

paramsへの実際の勾配w.rtを計算します。

Dparams = tf.gradients(after_relu, params, grad_ys=Dafter_relu_thresholded) 

あなたは簡単に多くのrelu層とネットワークのために、この同じ方法を拡張することができます。

+0

こんにちはKaveman、感謝のために多くの迅速な返信。あなたの最後のtf.gradientsコールでは、最初の引数としてDafter_reluを意図的に渡していますか? – Peter

+1

また、私はまだ混乱しているwrtです。 RELU要素を持つ任意のNNで動作する方法で、多くのレイヤーを持つネットワーク上でこれを一般化する方法。この部分では、各RELU要素の入出力をすべてトレースする必要はなく、以前に説明したロジックを '連鎖'する必要はありませんか?ありがとう。 – Peter

+0

@Peter、申し訳ありませんが、それはタイプミスでした。 'tf.gradients'への2回目の呼び出しは' after_relu' w.r.tです。 'params'。 – keveman

5

ops.RegisterGradienttf.Graph.gradient_override_mapの方が良いソリューション(アプローチ1)です。これらは、あらかじめ定義されたOpの勾配計算をオーバーライドします。 Pythonコードのみを使用してgradient_override_mapコンテキスト内のRelu。https://gist.github.com/falcondai/561d5eec7fed9ebf48751d124a77b087

更新

@ops.RegisterGradient("GuidedRelu") 
def _GuidedReluGrad(op, grad): 
    return tf.where(0. < grad, gen_nn_ops._relu_grad(grad, op.outputs[0]), tf.zeros(grad.get_shape())) 

... 
with g.gradient_override_map({'Relu': 'GuidedRelu'}): 
    y = tf.nn.relu(x) 

は、ここではガイド付きreluの完全な実装例であるTensorflowに> = 1.0、tf.selecttf.whereに変更されます。それに応じてスニペットを更新しました。 (ありがとう@私の注意にこれをもたらすためのsbond :)

+2

relu op *を含むグラフ構築を 'gradient_override_map'コンテキストの中にラップする必要があることに注意してください。 – Falcon

+1

ありがとう、@ファルコン、これはうまくいきます。 TFバージョン1.2を使用しているので、tf.selectをtf.whereで置き換える必要もありました。 – sbond

+0

@sbond更新をお寄せいただきありがとうございます。私はあなたのコメントを含めるために自分の投稿を編集しました。 – Falcon

関連する問題