2017-10-20 10 views
0

tf.train.MonitoredTrainingSessionを使ってCNNをトレーニングするコードがあります。Tensorflowは `tf.Session`を使って` tf.train.MonitoredTrainingSession`を使ってチェックポイントを保存しました

新しいtf.train.MonitoredTrainingSessionを作成すると、入力パラメータとしてcheckpointディレクトリをセッションに渡すことができ、最新の保存されたcheckpointが自動的に復元されます。そして、あるステップまで訓練するようにhooksを設定することができます。たとえば、checkpointのステップが150,000で、200,000までトレーニングしたい場合は、last_step200,000に設定します。

tf.train.MonitoredTrainingSessionを使用して最新のcheckpointが保存されている限り、上記のプロセスは完全に機能します。しかし、tf.Sessionを使用して保存されたcheckpointを復元しようとすると、すべての地獄が緩んでしまいます。グラフ内にいくつかのキーが見つかりません。

訓練はこれで行われている:checkpoint_dir属性はありませんチェックポイントのフォルダを持っていた場合

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir, 
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps), 
      tf.train.NanTensorHook(loss), 
      _LoggerHook()], 
    config=tf.ConfigProto(
     log_device_placement=FLAGS.log_device_placement)) as mon_sess: 
    while not mon_sess.should_stop(): 
    mon_sess.run(train_op) 

、これはすべての上を開始します。以前のトレーニングセッションで保存されたcheckpointがあった場合は、最新のcheckpointが復元され、トレーニングが継続されます。

今、私は最新のcheckpointを復元し、いくつかの変数を変更し、それらを保存しています:あなたが見ることができるように、ちょうどsaver.save...前に私は、ネットワーク内のすべての畳み込み層を剪定しています

saver = tf.train.Saver(variables_to_restore) 

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir) 

with tf.Session() as sess: 
    if ckpt and ckpt.model_checkpoint_path: 
    # Restores from checkpoint 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    print(ckpt.model_checkpoint_path) 
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps 
    else: 
    print('No checkpoint file found') 
    return 

    prune_convs(sess) 
    saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step) 

。どのように、なぜそれが行われたのかを記述する必要はありません。要点は、ネットワークが実際に変更されていることです。次に、ネットワークをcheckpointに保存します。

ここで、保存された変更されたネットワークにテストをデプロイすると、テストは正常に動作します。私は救われたcheckpointtf.train.MonitoredTrainingSessionを実行しようとすると、しかし、それは言う:

キーCONV1/weight_loss/avgのチェックポイントも

では見られないが、私が気づいたことcheckpointことtf.Sessionで、私はこの仕事を作る方法を、私は間違ったことをやっている任意の提案を知ってtf.train.MonitoredTrainingSession

で保存されたcheckpointのサイズの半分を持って保存されましたか?

答えて

0

私はそれを理解しました。明らかに、tf.Saverは、すべての変数をcheckpointから復元しません。私は直ちに復元と保存を試み、出力は半分の大きさでした。

tf.train.list_variablesを使用して、最新のcheckpointからすべての変数を取得し、tf.Variableに変換し、dictを作成しました。その後、dicttf.Saverに渡し、すべての変数を復元しました。

次に、すべての変数をinitializeにして、重みを変更しました。

これで機能しています。

関連する問題