私はRecurrent NNを使用して分類モデルを生成しています。以下はkeras
(このblogへの礼儀)を使用し、簡単なRNNモデルです:単純な列車/試験を列車/試験に変換してバッチを処理する
tr_C, ts_C, tr_r, ts_r = train_test_split(C, r, train_size=.8)
tr_C = tr_C.reshape((tr_C.shape[0], tr_C.shape[1], 1))
ts_C = ts_C.reshape((ts_C.shape[0], ts_C.shape[1], 1))
tr_r = tr_r.reshape(tr_r.shape[0], 1)
ts_r = ts_r.reshape(ts_r.shape[0], 1)
print('Build model...')
model = Sequential()
model.add(LSTM(16,
batch_input_shape=(1, 1, 1),
activation='tanh',
recurrent_activation='tanh',
kernel_initializer='he_uniform',
bias_initializer='he_uniform',
# dropout=0.1,
# recurrent_dropout=0.1,
stateful=True,
return_sequences=False,
)
)
model.add(Dense(1, activation='tanh'))
model.compile(loss='mean_absolute_error', optimizer='adam', metrics=['accuracy'])
for epoch in range(10):
print ('epoch(%s): ' % epoch)
mean_tr_acc = []
mean_tr_loss = []
for i in range(len(tr_C)):
for j in range(max_len):
tr_loss, tr_acc = model.train_on_batch(np.expand_dims(np.expand_dims(tr_C[i][j], axis=1), axis=1), tr_r[i])
mean_tr_acc.append(tr_acc)
mean_tr_loss.append(tr_loss)
model.reset_states()
print('training_acc = {}'.format(np.mean(mean_tr_acc)))
print('loss training = {}'.format(np.mean(mean_tr_loss)))
print('___________________________________')
mean_te_acc = []
mean_te_loss = []
for i in range(len(ts_C)):
for j in range(max_len):
te_loss, te_acc = model.test_on_batch(np.expand_dims(np.expand_dims(ts_C[i][j], axis=1), axis=1), ts_r[i])
mean_te_acc.append(te_acc)
mean_te_loss.append(te_loss)
model.reset_states()
for j in range(max_len):
y_pred = model.predict_on_batch(np.expand_dims(np.expand_dims(ts_C[i][j], axis=1), axis=1))
model.reset_states()
print('testing_acc = {}'.format(np.mean(mean_te_acc)))
print('loss testing = {}'.format(np.mean(mean_te_loss)))
print('___________________________________')
これは、モデルの概要です:
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (1, 16) 1152
_________________________________________________________________
dense_1 (Dense) (1, 1) 17
=================================================================
Total params: 1,169
Trainable params: 1,169
Non-trainable params: 0
_________________________________________________________________
代わりの時点で単一のサンプルを訓練し、することができます私はサンプルのバッチを訓練するためにこのコードを修正しますか?もしそうなら、どうですか?ご意見ありがとうございます。 ありがとうございます。
モデルの実装方法を教えてください。あなたは私たちにmodel.summary()を見せてもらえますか? –
@DvirSamuel Right!私はそれを追加することを忘れました。 Postは更新されました – Medo