2017-03-01 6 views
0

私はWGANを実装しており、ウェイト変数をクリップする必要があります。clip_by_weightのテンソル形式のケラス層の重み変数へのアクセス方法は?

私は現在、を使用しています。TensorflowKerasを高レベルAPIとして使用しています。したがって、変数を手動で作成して初期化することを避けるために、Kerasでレイヤーを構築します。

問題は、体重可変をクリップするWGANの必要性です。これは、一度それらの体重可変テンソルを得たらtf.clip_by_value(x, v0, v1)を使用して行うことができますが、私はそれらを安全に取得する方法はわかりません。

1つの可能な解決策は、おそらくtf.get_collection()を使用して、すべての訓練可能な変数を得ることです。しかし、私はどのように体重バイアスなし変数なしで取得する方法がわからない。

クリップ操作をする必要があるので、私はnumpyのAPIでそれらをクリップしてlayer.set_weights()を使用して、それらを設定するが、これはCPU-GPU法人が必要になる場合があり、そして良い選択ではないかもしれないことができますが、別の解決策はlayer.get_weights()ですが、それはnumpyの配列を取得します列車ステップごとに実行する必要があります。

正確な私はTFの下位レベルのAPIやTensorBoardから得ることができる変数名を使用して直接アクセスしますが、Kerasの命名規則は安定しているとは限りません。

clip_by_valueをTensorflowとKerasのWにのみ実行するクリーンな方法はありますか?

答えて

3

パラメータ(here)クラスを使用して、パラメータに新しい制約を実装できます。

ここでは、重み付きクリップを簡単に実装してモデルで使用する方法を示します。

from keras.constraints import Constraint 
from keras import backend as K 

class WeightClip(Constraint): 
    '''Clips the weights incident to each hidden unit to be inside a range 
    ''' 
    def __init__(self, c=2): 
     self.c = c 

    def __call__(self, p): 
     return K.clip(p, -self.c, self.c) 

    def get_config(self): 
     return {'name': self.__class__.__name__, 
       'c': self.c} 

import numpy as np 
from keras.models import Sequential 
from keras.layers import Dense 

model = Sequential() 
model.add(Dense(30, input_dim=100, W_constraint = WeightClip(2))) 
model.add(Dense(1)) 

model.compile(loss='mse', optimizer='rmsprop') 

X = np.random.random((1000,100)) 
Y = np.random.random((1000,1)) 

model.fit(X,Y) 

私は上記のコードの実行をテストしましたが、制約の有効性は検証していません。トレーニング後にモデルウエイトをmodel.get_weights()またはmodel.layers[idx].get_weights()で取得し、その制約を守っているかどうかをチェックすることでそうすることができます。

注:制約は、すべてのモデルの重みに追加されません..しかし、ちょうどその使用される特定の層の重みにもW_constraintb(バイアス)のparam

+0

Wのparamとb_constraintに制約を加算ありがとうあなたの返事のために。私はy = ax + bのおもちゃの問題でそれを試しました、それは完璧に動作します! – soar0x48

+0

あなたは大歓迎です。 – indraforyou

+0

Tensorflowで実装する際に問題があります。 model.fitを使用すると、完全に動作しますが、TFコマンドでsess.run(tran_step、...)を使用すると、制約が機能しないようです。どのようにそれを修正するための任意のアイデア? – soar0x48

関連する問題