2017-03-14 5 views
2

最近、が読み込まれ、KL発散の損失に値が依存する変数を乗算する「ウォームアップ」(WU)というプロセスが導入されました。エポックの数(それは0から1まで直線的に進化する)Variationnal自動エンコーダ:Kerasでウォームアップを実装

これはそれを行うための良い方法である場合、私は思っていた:

beta = K.variable(value=0.0) 

def vae_loss(x, x_decoded_mean): 
    # cross entropy 
    xent_loss = K.mean(objectives.categorical_crossentropy(x, x_decoded_mean)) 

    # kl divergence 
    for k in range(n_sample): 
     epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0., 
           std=1.0) # used for every z_i sampling 
     # Sample several layers of latent variables 
     for mean, var in zip(means, variances): 
      z_ = mean + K.exp(K.log(var)/2) * epsilon 

      # build z 
      try: 
       z = tf.concat([z, z_], -1) 
      except NameError: 
       z = z_ 
      except TypeError: 
       z = z_ 

      # sum loss (using a MC approximation) 
      try: 
       loss += K.sum(log_normal2(z_, mean, K.log(var)), -1) 
      except NameError: 
       loss = K.sum(log_normal2(z_, mean, K.log(var)), -1) 
     print("z", z) 
     loss -= K.sum(log_stdnormal(z) , -1) 
     z = None 
    kl_loss = loss/n_sample 
    print('kl loss:', kl_loss) 

    # result 
    result = beta*kl_loss + xent_loss 
    return result 

# define callback to change the value of beta at each epoch 
def warmup(epoch): 
    value = (epoch/10.0) * (epoch <= 10.0) + 1.0 * (epoch > 10.0) 
    print("beta:", value) 
    beta = K.variable(value=value) 

from keras.callbacks import LambdaCallback 
wu_cb = LambdaCallback(on_epoch_end=lambda epoch, log: warmup(epoch)) 


# train model 
vae.fit(
    padded_X_train[:last_train,:,:], 
    padded_X_train[:last_train,:,:], 
    batch_size=batch_size, 
    nb_epoch=nb_epoch, 
    verbose=0, 
    callbacks=[tb, wu_cb], 
    validation_data=(padded_X_test[:last_test,:,:], padded_X_test[:last_test,:,:]) 
) 
+0

これはうまくいくようです。私は答え[ここ](https://github.com/fchollet/keras/issues/2595)は少しきれいかもしれないと思いますが、あなたのコールバックをベータのパラメータにしてください。あなたが上でやっているようにグローバル。 – mobeets

+0

私はそのスレッドを見ていないありがとうございます – sbaur

答えて

1

これは動作しません。私はそれが動作していない理由を正確に把握するためにそれをテストしました。覚えておくべき重要なことは、Kerasがトレーニングの初めに静的グラフを作成することです。

したがって、vae_loss関数は、損失テンソルを作成するために1回だけ呼び出されます。つまり、損失が計算されるたびに変数betaへの参照が同じままになります。ただし、warmup関数はベータを新しいK.variableに再割り当てします。したがって、損失を計算するために使用されるbetaは、更新される値と異なるbetaであり、値は常に0になります。

これは簡単な修正です。

beta = K.variable(value=value)

へ:ちょうどあなたのwarmupコールバックで、この行を変更

K.set_value(beta, value)

betaの実際の値は、むしろ新しい変数を作成するよりも、「場所に」更新されます。この方法で、と損失は​​適切に再計算されます。

+0

ありがとう! :)私はすぐにそれを試してみる – sbaur

関連する問題