2017-07-17 7 views
2

私はKerasでGANを訓練したいと思います。私の最終目標はBEGANですが、私は最も簡単なものから始めます。 を理解するにはのウェイトを適切に固定する必要があります。それが私が苦労していることです。Kerasでモデルをコンパイルした後に動的に凍結する方法は?

発電機の訓練時間中、弁別器の重量が更新されないことがあります。 フリーズフリーダイヤル交互に訓練ジェネレータとdiscriminatorのためのdiscriminatorをしたいと思います。問題は、trainableパラメータをディスクリミネータモデルモデルまたはfalseのfalseに設定することです。つまり、そのウェイトはトレーニングするモデル(および更新するウェイト)を停止しません。一方、をtrainableに設定してモデルをコンパイルすると、はFalseになり、重みはになります。になります。私は、各反復後にモデルをコンパイルすることはできません。それは、全体のトレーニングという考えを否定するからです。

この問題のために、多くのKerasの実装が盗まれているか、古いバージョンなどの直感的ではないトリックのために動作しているようです。

答えて

1

は、私は数ヶ月前にこのサンプルコードを試してみたし、それが働いた: https://github.com/fchollet/keras/blob/master/examples/mnist_acgan.py

それはGANの最も単純な形ではないが、限り、私は思い出したように、それは分類の損失を削除するにはあまりにも難しいことではありませんし、モデルをGANに変換します。

discriminatorのtrainableプロパティをオン/オフして再コンパイルする必要はありません。 trainable=True(コードではdiscriminator)とtrainable=False(コードでは)という2つのモデルオブジェクトを作成してコンパイルするだけです。

ディスクリミネータを更新する場合は、discriminator.train_on_batch()に電話してください。発電機を更新する場合は、combined.train_on_batch()に電話してください。

0

tf.stop_gradientを条件付きで凍結することができますか?

+0

'tf.stop_gradient'はグラデーションの流れを止めています。これは達成したいものではありません。グラデーションフローを作成し、ウェイトのグラデーションを計算したいが、更新操作は実行しない。 –

+0

次に、いつでも重みをフリーズ/フリーズ解除するのではなく、更新したい変数のリストを明示的にテンソルフロー更新オペレーションに渡す方が良いかもしれません。 –

+0

あなたは正しいですが、それはTensorflowの解決策であり、Kerasはそれを許可していません。あなたはモデル抽象化を持ち、主に 'fit'メソッドと' train_on_batch'メソッドを持っています。それだけです。純粋なケラスに解決策がない場合は、私はTensorflowに切り替えます。 –

0

おそらく、あなたの敵対的なネット(ジェネレータとディスクリミネータ)が「モデル」に書かれているかもしれません。 しかし、d.trainable = Falseを設定しても、独立したdネットは訓練できないように設定されますが、敵対的ネット全体のdはまだ訓練可能です。

d.trainable = Falseと設定した後にd_on_g.summary()を使用することができます。あなたは何を意味するのか知っているでしょう(訓練可能な変数に注意してください)。

関連する問題