2017-04-05 24 views
4

だから私たちは基本的にkerasステートフルLSTM

Y[i] = X[i-1] 

を持っており、モデルは単に遅れ演算子であり、この単純な例

nb_samples = 100000 
X = np.random.randn(nb_samples) 
Y = X[1:] 
X = X[:-1] 
X = X.reshape((len(Y), 1, 1)) 
Y = Y.reshape((len(Y), 1)) 

をご検討ください。

私はこのモデルをステートレスLSTMで学ぶことができますが、ここではKerasでステートフルなLSTMを理解して適用したいと思います。

だから私は、値

model = Sequential() 
model.add(LSTM(batch_input_shape=(1, 1, 1), 
       output_dim =10, 
       activation='tanh', stateful=True 
     ) 
    ) 
model.add(Dense(output_dim=1, activation='linear')) 
model.compile(loss='mse', optimizer='adam') 


for epoch in range(50): 
    model.fit(X_train, 
       Y_train, 
       nb_epoch = 1, 
       verbose = 2, 
       batch_size = 1, 
       shuffle = False) 
    model.reset_states() 

ずつのペアを与えることによって、ステートフルLSTMで、このモデルを学習しようとする。しかしこのモデルは何かを学ぶことはありません。

for epoch in range(10000): 
    model.reset_states() 
    train_loss = 0 
    for i in range(Y_train.shape[0]): 
     train_loss += model.train_on_batch(X_train[i:i+1], 
         Y_train[i:i+1], 
         ) 
    print '# epoch', epoch, ' loss ', train_loss/float(Y_train.shape[0]) 

を私はまだ私のランダムに生成されたデータの標準偏差である1の周りの平均損失を、見ていたモデルはないので、次のように

は、マルチンの提案を1として、私はトレーニングのコードを修正しました学ばないようです。

私は間違っていますか?

+1

ユニット数を増やしてみましたか?あなたのネットワークは長さ10000の完全にランダムなパターンをメモする必要があります。実際、シーケンスとして完全にメモに記入する必要があります。 '10'単位ではこれでは不十分かもしれません。シークエンスの長さを短くしたり、いくつかの連続関数( 'sin'や多項式など)を調べようとすることもできます。現時点では、あなたのアーキテクチャーはあなたの仕事にとってシンプルであるようです。 –

+0

揮発性LSTMは、しばしばランニングアベレージを学習します。 'randn'の実行平均は0になります。これがあなたの出力であれば、学習は実際に成功しました。有意義なものを学ぶことを試みてください。 – nemo

+0

@nemo:答えに感謝します。私は長さ10000のランダムパターンを学習するのではなく、tの出力をt-1の入力とするパターンを学習しています。私は、入力x [t]を隠れた状態にしてt + 1に置き、その隠れた状態を出力y [t + 1]として返し、隠れた状態をx [ t + 1]であり、再帰的にそれを行う。明らかに、私はムービングウィンドウ付きのステートレスLSTMを使用してこれを達成できますが、ステートフルなもので結果を得たいと思っています – volatile

答えて

1

モデルの状態がネットワークのステートフルな状態(パラメータがオプティマイザであるためにリセットされない)であっても、hereと表示されているように、オプティマイザがリカレントニューラルネットワークトレーニングで非常に重要であるため、あなたのトレーニングに非常に有害な可能性があります。その試みを防ぐために:あなたのオプティマイザをリセットしません

for epoch in range(50): 
    model.train_on_batch(X_train, 
       Y_train) 
    model.reset_states() 

train_on_batch方法は、あなたのトレーニングを可能にすることができるもの述べています。

+0

ありがとうございます。私は試しましたが、うまくいかないようです。私は自分の質問を更新して、私が何をしているのかを伝える – volatile

関連する問題