kerasのサンプルディレクトリには、MNISTデータで学習する、積み上げたwhat-whereオートエンコーダ(SWWAE)の軽量版が含まれています。 (https://github.com/fchollet/keras/blob/master/examples/mnist_swwae.py)kerasのmaxpoolingポジションを得るために合計トリックのグラディエントはどのように機能しますか?
オリジナルのSWWAE論文では、著者はソフト関数を使用して何をどこで計算するのかを計算します。しかし、ケラスの実装では、これらの場所を取得するためにトリックを使用します。私はこのトリックを理解したいと思います。
ここにトリックのコードがあります。 y_prepoolは、M×Nの行列であり、y_postpoolがM/2×N/2行列である
def getwhere(x):
''' Calculate the 'where' mask that contains switches indicating which
index contained the max value when MaxPool2D was applied. Using the
gradient of the sum is a nice trick to keep everything high level.'''
y_prepool, y_postpool = x
return K.gradients(K.sum(y_postpool), y_prepool) # How exactly does this line work?
は(サイズ2つの画素の正規プーリングを想定することができます)。
私はgetwhere()の出力が釘のベッドであり、釘がmaxの位置を示していることを確認しました(もしあればローカルargmax)。
この「トリック」を使用してgetwhereがどのように動作するかを示す小さな例を作成できますか?
@lejotそれはdownvoted(それは素晴らしい書き込み)にもかかわらず、これを書く時間をとってくれてありがとう。 1つの質問、私はかなりK.sum()を理解していません。私はまったくフォローしていません、「今、「適切なマックスプール」のために....「別の方法を説明できますか? –
私が説明した内容は、単一のmaxpoolウィンドウで機能します。さて、あなたはこれらの多くを持っています。そして、合計の勾配が勾配の合計であるので、それは多くのそのような「呼び出し」を共同して表現する自然な方法です。また、K.sum()を実行する代わりに、それぞれの最大プール領域のグラデーションを個別に求めてから、すべてのマスクを追加することもできます(これは最終的に1より大きい値になります)。これはもっと明確ですか? – lejlot
はい、ありがとうございます! –