1
Iは示された構造と、以下の操作を行って、保存されたファイルからオートエンコーダをロードしています:Keras:最初のn層を取得
autoencoder = load_model("autoencoder_mse1.h5")
autoencoder.summary()
>>> ____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_8 (InputLayer) (None, 19) 0
____________________________________________________________________________________________________
dense_43 (Dense) (None, 16) 320 input_8[0][0]
____________________________________________________________________________________________________
dense_44 (Dense) (None, 16) 272 dense_43[0][0]
____________________________________________________________________________________________________
dense_45 (Dense) (None, 2) 34 dense_44[0][0]
____________________________________________________________________________________________________
dense_46 (Dense) (None, 16) 48 dense_45[0][0]
____________________________________________________________________________________________________
dense_47 (Dense) (None, 16) 272 dense_46[0][0]
____________________________________________________________________________________________________
dense_48 (Dense) (None, 19) 323 dense_47[0][0]
====================================================================================================
Total params: 1269
__________________
InputLayer
含む最初の4つの層は、エンコーダの一部を構成します。これらの4つのレイヤーをすばやく取得する方法があるのだろうかと思っていました。これまで私が遭遇した唯一の解決策は、
encoder = Sequential()
encoder.add(Dense(16, 19, weights=autoencoder.layers[1].get_weights()))
^です。より効率的な方法で最初の4つのレイヤーを抽出する方法があることを期待しています。特に、.summary()
メソッドがレイヤーの要約を吐き出すので。
編集1(解決策): 私は以下の解決策に到達しましたが、私はもう少し効率の良い(コードが少ない)ことを望んでいます。
encoder = Sequential()
for i,l in enumerate(autoencoder.layers[1:]):
if i==0:
encoder.add(Dense(input_dim=data.shape[1],output_dim=l.output_dim,
activation="relu",weights=l.get_weights()))
else:
encoder.add(Dense(output_dim=l.output_dim,activation="relu",weights=l.get_weights()))
if l.output_dim == 2:
break
私はあなたが何を得ようとしているかを見ていると思うが、この答えはコメントでなければならない。それは私の質問に直接答えません。私はこのメソッドをウェイトにすることができますが、私はウェイトのサイズとスタックのスタックのタイプを知る必要があります。 –
答えを編集しました。これが機能するかどうかを確認してください。 –
それは機能しましたか?私にお知らせください。 –