私はtensorflow1.1でESPCNを訓練しましたが、訓練の際に1パッチあたりのコストがほぼ直線的に増加しました。最初の100エポックは4〜5秒しかかかりませんが、70エポックは約30分かかります。以下のトレーニング結果を参照してください:トレーニングの際にtensorflow1.1がなぜ遅くなり、遅くなるのですか?メモリリークか待ち行列の飢餓ですか?
私はGoogleとスタックオーバーフローで同じ質問を検索し、以下のソリューションを試してみましたが、仕事全く思えませんしました: は、すべてのsess.run()
後tf.reset_default_graph()
を1.add ; 2.キューの枯渇を防ぐためにtime.sleep(5)
を追加します。
私は一般的な考え方を知っています。つまり、Session()の操作を減らすことです。しかしどうですか?誰にでも解決策がありますか?ここで
は、私のコードの一部です:
L3, var_w_list, var_b_list = model_train(IN, FLAGS)
cost = tf.reduce_mean(tf.reduce_sum(tf.square(OUT - L3), reduction_indices=0))
global_step = tf.Variable(0, trainable=False)
learning_rate = tf.train.exponential_decay(FLAGS.base_lr, global_step * FLAGS.batch_size, FLAGS.decay_step, 0.96, staircase=True)
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step = global_step, var_list = var_w_list + var_b_list)
# optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9).minimize(cost, var_list = var_w_list + var_b_list)
cnt = 0
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
print('\n\n\n =========== All initialization finished, now training begins ===========\n\n\n')
t_start = time.time()
t1 = t_start
for i in range(1, FLAGS.max_Epoch + 1):
LR_batch, HR_batch = batch.__next__()
global_step += 1
[_, cost1] = sess.run([optimizer, cost], feed_dict = {IN: LR_batch, OUT: HR_batch})
# tf.reset_default_graph()
if i % 100 == 0 or i == 1:
print_step = i
print_loss = cost1/FLAGS.batch_size
test_LR_batch, test_HR_batch = test_batch.__next__()
test_SR_batch = test_HR_batch.copy()
test_SR_batch[:,:,:,0:3] = sess.run(L3, feed_dict = {IN: test_LR_batch[:,:,:,0:3]})
# tf.reset_default_graph()
psnr_tmp = 0.0
ssim_tmp = 0.0
for k in range(test_SR_batch.shape[0]):
com1 = test_SR_batch[k, :, :, 0]
com2 = test_HR_batch[k, :, :, 0]
psnr_tmp += get_psnr(com1, com2, FLAGS.HR_size, FLAGS.HR_size)
ssim_tmp += get_ssim(com1, com2, FLAGS.HR_size, FLAGS.HR_size)
psnr[cnt] = psnr_tmp/test_SR_batch.shape[0]
ssim[cnt] = ssim_tmp/test_SR_batch.shape[0]
ep[cnt] = print_step
t2 = time.time()
print_time = t2 - t1
t1 = t2
print(("[Epoch] : {0:d} [Current cost] : {1:5.8f} \t [Validation PSNR] : {2:5.8f} \t [Duration time] : {3:10.8f} s \n").format(print_step, print_loss, psnr[cnt], print_time))
# tf.reset_default_graph()
cnt += 1
if i % 1000 == 0:
L3_test = model_test(IN_TEST, var_w_list, var_b_list, FLAGS)
output_img = single_HR.copy()
output_img[:,:,:,0:3] = sess.run(L3_test, feed_dict = {IN_TEST:single_LR[:,:,:,0:3]})
tf.reset_default_graph()
subname = FLAGS.img_save_dir + '/' + str(i) + ".jpg"
img_gen(output_img[0,:,:,:], subname)
print(('================= Saving model to {}/model.ckpt ================= \n').format(FLAGS.checkpoint_dir))
time.sleep(5)
# saver.save(sess, FLAGS.checkpoint_dir + '/model.ckpt', print_step)
t_tmp = time.time() - t_start
私の構成は次のとおりです。windows10 + tf1.1 + python3.5 + cuda8.0 + cudnn5.1
======== ========================================== ======
さらに、最後のレイヤーでデコンボリューションの代わりにピクセルシャッフル(PS)レイヤーを使用しました。私は、以下に示す他のPSコード、コピー:
def _phase_shift(I, r):
bsize, a, b, c = I.get_shape().as_list()
bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
X = tf.reshape(I, (bsize, a, b, r, r))
X = tf.transpose(X, (0, 1, 2, 4, 3)) # bsize, a, b, 1, 1
X = tf.split(X, a, 1) # a, [bsize, b, r, r]
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2) # bsize, b, a*r, r
X = tf.split(X, b, 1) # b, [bsize, a*r, r]
X = tf.concat([tf.squeeze(x, axis=1) for x in X], 2) # bsize, a*r, b*r
return tf.reshape(X, (bsize, a*r, b*r, 1))
def PS(X, r, color=False):
if color:
Xc = tf.split(X, 3, 3)
X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
else:
X = _phase_shift(X, r)
return X
X
4次元画像テンソルであり、r
がアップスケーリング因子を意味し、color
は、画像のチャンネル3(YCbCrフォーマットであるかどうかを決定します)または1(グレースケール形式)。デコンボリューション層を使用したときに、プログラムがうまくいくため
L3_ps = PS(L3, scale, True)
は、今私は、この層は減速を引き起こしたかどうかを思ったんだけど:レイヤーを使用するには
はtf.nn.relu()
がするよう、非常に簡単です。デコンボリューション層を使うのは解決策かもしれませんが、何らかの理由でPS層を使わなければなりません。
sess.runコールの時間を計ることはできますか?それはtfランタイムを使用する唯一の部分であり、どこか他の場所では速度が遅くなる可能性があります。あなたがtfを使用する方法は、ループの間にグラフの成長を引き起こさないように見えます(遅さにつながります)。 –
私はあなたの提案を試しました。 '[_、cost1] = sess.run([オプティマイザ、コスト]、feed_dict = {IN:LR_batch、OUT:HR_batch})ごとにタイミングを取っています。すべての時代が減速するのを見ました。私はPS層が質問の記述の後に追加された減速を引き起こしたかどうか疑問に思っています。@AlexandrePassos – Cyiano
私は今、この正確な問題を観察しています。あなたはそれを解決しましたか? –