私は、バイナリ画像分類の問題のためにいくつかのチュートリアルコードを実行しています。非常にシンプルなアーキテクチャ(3つの畳み込み/リル/プーリング+完全接続)が、各エポックの最後のトレーニングステップは約130秒かかりますが、最初の127は合計20秒かかります。誰もがこれを説明することができ、私はどういうわけかこれをスピードアップできますか?私は2GBのVRAMを搭載したGPUで動いています。convnetを訓練するとき、なぜ各エポックの最後のステージが長くかかるのですか?
rmsprop = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)
model.compile(loss='binary_crossentropy',
optimizer=rmsprop,
metrics=['accuracy'])
nb_epoch = 30
nb_train_samples = 2048
nb_validation_samples = 832
model.summary()
model.fit_generator(
train_generator,
samples_per_epoch=nb_train_samples,
nb_epoch=nb_epoch,
validation_data=validation_generator,
nb_val_samples=nb_validation_samples)
127/128 [============================>.] - ETA: 0s - loss: 0.7302 - acc: 0.5266
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 148, 148, 32) 896
_________________________________________________________________
activation_1 (Activation) (None, 148, 148, 32) 0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 74, 74, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 72, 72, 32) 9248
_________________________________________________________________
activation_2 (Activation) (None, 72, 72, 32) 0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 36, 36, 32) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 34, 34, 64) 18496
_________________________________________________________________
activation_3 (Activation) (None, 34, 34, 64) 0
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 17, 17, 64) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 18496) 0
_________________________________________________________________
dense_1 (Dense) (None, 64) 1183808
_________________________________________________________________
activation_4 (Activation) (None, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 64) 0
_________________________________________________________________
dense_2 (Dense) (None, 1) 65
_________________________________________________________________
activation_5 (Activation) (None, 1) 0
=================================================================
Total params: 1,212,513.0
Trainable params: 1,212,513.0
Non-trainable params: 0.0
_________________________________________________________________
あなたは 'validation_generator'が指定されているので、ネットワークが検証を行っています。その引数を省略すると、遅延が発生しません。 –