tf.train.MonitoredTrainingSession
を使ってCNNをトレーニングするコードがあります。Tensorflowは `tf.Session`を使って` tf.train.MonitoredTrainingSession`を使ってチェックポイントを保存しました
新しいtf.train.MonitoredTrainingSession
を作成すると、入力パラメータとしてcheckpoint
ディレクトリをセッションに渡すことができ、最新の保存されたcheckpoint
が自動的に復元されます。そして、あるステップまで訓練するようにhooks
を設定することができます。たとえば、checkpoint
のステップが150,000
で、200,000
までトレーニングしたい場合は、last_step
を200,000
に設定します。
tf.train.MonitoredTrainingSession
を使用して最新のcheckpoint
が保存されている限り、上記のプロセスは完全に機能します。しかし、tf.Session
を使用して保存されたcheckpoint
を復元しようとすると、すべての地獄が緩んでしまいます。グラフ内にいくつかのキーが見つかりません。
訓練はこれで行われている:checkpoint_dir
属性はありませんチェックポイントのフォルダを持っていた場合
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
、これはすべての上を開始します。以前のトレーニングセッションで保存されたcheckpoint
があった場合は、最新のcheckpoint
が復元され、トレーニングが継続されます。
今、私は最新のcheckpoint
を復元し、いくつかの変数を変更し、それらを保存しています:あなたが見ることができるように、ちょうどsaver.save...
前に私は、ネットワーク内のすべての畳み込み層を剪定しています
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
。どのように、なぜそれが行われたのかを記述する必要はありません。要点は、ネットワークが実際に変更されていることです。次に、ネットワークをcheckpoint
に保存します。
ここで、保存された変更されたネットワークにテストをデプロイすると、テストは正常に動作します。私は救われたcheckpoint
にtf.train.MonitoredTrainingSession
を実行しようとすると、しかし、それは言う:
キーCONV1/weight_loss/avgのチェックポイントも
では見られないが、私が気づいたことcheckpoint
ことtf.Session
で、私はこの仕事を作る方法を、私は間違ったことをやっている任意の提案を知ってtf.train.MonitoredTrainingSession
で保存されたcheckpoint
のサイズの半分を持って保存されましたか?