2017-05-31 21 views
2

の列車選択分岐が、私はこれに似た構造を持つマルチ出力Kerasモデルを持っている:マルチ出力kerasモデル

s = some_shared_layers()(input) 
non_trainable1 = Dense(trainable=False) (s) 
non_trainable2 = Dense(trainable=False) (s) 
trainable = Dense() (s) 

model = Model(input, outputs=[non_trainable1, non_trainable2, trainable]) 

私のモデルは、最初のフォワード・パスを計算し、入力を操作することが第一2つの出力を使用しています。次に、別の順方向パスを計算して3番目の出力を取得します。私はtrainable層を訓練するmodel.fit()呼び出す必要がありますどのように

out1, out2,_ =model.predict(input_data) 
processed_data = foo(input_data, out1, out2) 
_,_, out3 = model.predict(processed_data) 

?他の出力の損失を除外すると、Kerasはwe will not be expecting any data to be passed to "non_trainable1" during trainingと警告し、それらを計算グラフから除外します。

このタスクのモデルを構成する良い方法はありますか?

答えて

0

私が正しく理解していれば、それらのレイヤーは必要ありません。実際には、2つのモデル、1つは予測のみ、もう1つはトレーニング用です。

非トレーニング可能:

model1 = Model(input, [non_trainable1, non_trainable2]) 
#model 1 doesn't need to be compiled, since you won't train it  

訓練可能:それらを使用して

model2 = Model(input, trainable) 
model2.compile(loss=onlyTheLossForTrainable)  

out1, out2 =model1.predict(input_data) 
processed_data = foo(input_data, out1, out2) 

model2.fit(processed_data, expected_outputs, ....)  
+0

私は現在、回避策としてそれを使用していますあなたが記述したようにコードを実行すると動作します。しかし、 'ImageDataGenerator'の中でもっと複雑な方法で' foo'という処理関数を使用しようとすると、 'model2.fit_generator()'はValueErrorを発生させます:Tensor "non_trainable1"はこのグラフの要素ではありません。 – Manas

+0

なぜ発電機の中で??そして、あなたはジェネレータの中で正確に何をしようとしていますか?予測を使用する代わりにテンソルを使用しているように聞こえます。 –

+0

私は画像で作業しており、データ拡大にジェネレータを使用しています。素朴な最初の試みとして、 'preprocessing_function'引数(ref [docs](https://keras.io/preprocessing/image/))として' foo'を渡そうとしました。 'foo'はイメージをとり、' model1.predict() 'を呼び出し、編集されたイメージを返します。 – Manas

関連する問題