2017-10-01 34 views
1

テンソルフローバックエンドでケラスを使用しています。各タイムステップでのLSTMモデルの隠れ状態ベクトルの抽出

私は、各タイムステップでを非表示状態ベクトルIとして抽出したい訓練されたLSTMモデルを持っています

ケラスでこれを行うにはどうすればよいですか?

+0

LSTMレイヤのパラメータと使用しているKerasのバージョンは何ですか? (LSTMの文献では普通は 'c 'と表記されています)、隠れた状態のベクトル(普通は' h'とも呼ばれ、レイヤーの出力でもあります)です。後者ははるかに簡単です。 –

+0

私は3つのスタックされた双方向LSTMを持っています。私は最上位層の隠れた状態ベクトルを必要とします(私はセル状態を必要としません)。私はkeras 2.0.6を使用しています – Kiran

答えて

1

すべての隠し状態ベクトルが返されるかどうかを処理する関数はRecurrent.call()です(最新バージョンではRNN.call()に名前が変更されています)。パラメータreturn_sequencesをチェックして決定します。バックエンド機能K.rnn()は、この関数で呼び出され

last_output, outputs, states = K.rnn(self.step, 
            preprocessed_input, 
            initial_state, 
            go_backwards=self.go_backwards, 
            mask=mask, 
            constants=constants, 
            unroll=self.unroll, 
            input_length=input_shape[1]) 

... 

if self.return_sequences: 
    output = outputs 
else: 
    output = last_output 

テンソルoutputsは、あなたが望むものです。 Recurrent.call()をもう一度呼び出すと、return_sequences=Trueでこのテンソルを得ることができます。これはあなたの訓練されたLSTMモデル(少なくとも現在のKerasでは)に害を及ぼさないはずです。


は、ここでは、この方法を実証おもちゃバイLSTMモデルです:

input_tensor = Input(shape=(None,), dtype='int32') 
embedding = Embedding(10, 100, mask_zero=True)(input_tensor) 
hidden = Bidirectional(LSTM(10, return_sequences=True))(embedding) 
hidden = Bidirectional(LSTM(10, return_sequences=True))(hidden) 
hidden = Bidirectional(LSTM(2))(hidden) 
out = Dense(1, activation='sigmoid')(hidden) 
model = Model(input_tensor, out) 

まず、あなたがBidirectionalラッパーを使用していることから、あなたが設定する必要が(最後LSTM層のためのTruereturn_sequencesを設定forward_layerともbackward_layer):

target_layer = model.layers[-2] 
target_layer.return_sequences = True 
target_layer.forward_layer.return_sequences = True 
target_layer.backward_layer.return_sequences = True 

今、もう一度、この層を呼び出すことによって、テンソルcontai隠れたベクトルは常に返されます(追加のインバウンドノードを作成するという副作用がありますが、予測には影響しません)。

たとえば、m.predict(X_test)を呼び出して隠れベクターを取得できます。

X_test = np.array([[1, 3, 2, 0, 0]]) 
print(m.predict(X_test)) 

[[[ 0.00113332 -0.0006666 0.00428438 -0.00125567] 
    [ 0.00106074 -0.00041183 0.00383953 -0.00027285] 
    [ 0.00080892 0.00027685 0.00238486 0.00036328] 
    [ 0.00080892 0.00027685 0.   0.  ] 
    [ 0.00080892 0.00027685 0.   0.  ]]] 

ご覧のとおり、5つのタイムステップすべての隠れたベクトルが返され、最後の2つのタイムステップが適切にマスクされます。

+0

'K.function'を使うよりも理解しやすい' predict() 'を使うように編集しました。 –

関連する問題