2017-01-14 6 views
0

カスタムレイヤーを書きたいと思います。ここでは、実行の間に変数をメモリに保持できます。私はm.predictを実行すると は例えば、ケラスカスタムレイヤーの永続変数

class MyLayer(Layer): 
def __init__(self, out_dim = 51, **kwargs): 
    self.out_dim = out_dim 
    super(MyLayer, self).__init__(**kwargs) 

def build(self, input_shape): 
    a = 0.0 
    self.persistent_variable = K.variable(a) 
    self.built = True 

def get_output_shape_for(self, input_shape): 
    return (input_shape[0], 1) 

def call(self, x, mask=None): 
    a = K.eval(self.persistent_variable) + 1 
    K.set_value(self.persistent_variable, a) 
    return self.persistent_variable 

m = Sequential() 
m.add(MyLayer(input_shape=(1,))) 

は、私はpersistent_variableが更新されますことを期待し、インクリメント値を印刷します。それは常に私の質問がある0

# Dummy input 
x = np.zeros(1) 

m.predict(x, batch_size=1) 

印刷しますよう は、しかし、それは、どのように私はpersistent_variable増分を作り、m.predict

感謝のすべての実行後に保存しないに見え、 のNaveen

答えて

1

トリックということです呼び出し関数にself.add_update(...)を呼び出して、モデルが評価されるたびに呼び出される関数を登録する必要があります(ステートフル言語のソースコードを掘り下げて見つけました)。 self.stateful = Trueを実行すると、トレーニングと予測のコールごとにカスタム更新機能が呼び出されます。それ以外の場合は、トレーニング中に呼び出されます。

import keras.backend as K 
import numpy as np 
from keras.engine.topology import Layer 

class CounterLayer(Layer): 
    def __init__(self, stateful=False,**kwargs): 
    self.stateful = stateful # True means it will increment counter on predict and train, false means it will only increment counter on train 
    super(CounterLayer, self).__init__(**kwargs) 


    def build(self, input_shape): 
    # Define variables in build 
    self.count = K.variable(0, name="count") 
    super(CounterLayer, self).build(input_shape) 

    def call(self, x, mask=None): 
    updates = [] 
    # The format is (variable, value setting to) 
    # So this says 
    # self.pos = self.pos + 1 
    updates.append((self.count, self.count+1)) 

    # You can append more updates to this list or call add_update more 
    # times if you want 

    # Add our custom update 

    # We stick x here so it calls our update function every time our layer 
    # is given a new x 
    self.add_update(updates, x) 

    # This will be an identity layer but keras gets mad for some reason 
    # if you just output x so we'll multiply it by 1 so it thinks it is a 
    # "new variable" 
    return self.count 
    # in newer keras versions you might need to name this compute_output_shape instead 
    def get_output_shape_for(self, input_shape): 
    # We will just return our count as an array ([[count]]) 
    return (1,1) 

    def reset_states(self): 
    self.count.set_value(0) 

使用例:たとえば

from keras.layers import Input 
from keras.models import Model 
from keras.optimizers import RMSprop 
inputLayer = Input(shape=(10,)) 
counter = CounterLayer() # Don't update on predict 
# counter = CounterLayer(stateful=True) # This will update each time you call predict 
counterLayer = counter(inputLayer) 
model = Model(input=inputLayer, output=counterLayer) 
optimizer = RMSprop(lr=0.001) 
model.compile(loss="mse", optimizer=optimizer) 


# See the value of our counter 
print counter.count.get_value() 

# This won't actually train anything but each epoch will update our counter 

# Note that if you say have a batch size of 5, update will be called 5 times per epoch 
model.fit(np.zeros([1, 10]), np.array([0]), batch_size=1, nb_epoch=5) 

# The value of our counter has now changed 
print counter.count.get_value() 

model.predict(np.zeros([1, 10])) 

# If we did stateful=False, this didn't change, otherwise it did 
print counter.count.get_value() 
+0

こんにちはPhylliida、 が正解のように見えます。しかし、それは時々動作しません。 'a = model.predict(np.random.rand(100、10)、batch_size = 1) print(a)' '[0.1]を実行しました。 。9. 10. 10. 11. ....] ' 何度か更新がありません。 –

+0

これは何らかの競合状態かもしれません。私は実際に申し訳ありませんが、他の誰かが知っているかどうかを見るために待つことができます – Phylliida

+1

あなたは正しいです。ケラスに競合状態が存在する可能性があります。 'CounterLayer'の後に' RepeatVector'レイヤーを追加しました。 –

関連する問題