2017-03-01 21 views
3

私の現在のプロジェクトでは、100回の繰り返しステップごとにモデルをトレーニングし、チェックポイントを保存します。チェックポイントファイルはすべて同じディレクトリ(model.ckpt-100、model.ckpt-200、model.ckpt-300など)に保存されます。その後、最新のものだけでなく、保存されたすべてのチェックポイントの検証データに基づいてモデルを評価したいと思います。テンソルフロー:複数のチェックポイントでモデル評価を実行

現在、チェックポイントファイルを復元するためのコードの私の作品は、次のようになります。

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 
ckpt_list = saver.last_checkpoints 
print(ckpt_list) 
if ckpt and ckpt.model_checkpoint_path: 
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path) 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    # extract global_step from it. 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 
else: 
    print('No checkpoint file found') 
    return 

しかし、これは唯一、最新の保存されたチェックポイントファイルを復元します。では、保存されたすべてのチェックポイントファイルにループを書き込むにはどうすればよいですか? saver.last_checkpointsでチェックポイントファイルのリストを取得しようとしましたが、返されるリストは空です。

ご協力いただきまして、ありがとうございます。

+0

モデルを正確に保存していますか? 'saver.save(..)'を呼び出す際に、出力ファイルの名前を自分で作成しているのでしょうか、あるいは 'global_step'パラメータを使用していますか? – kaufmanu

答えて

1

あなたはディレクトリ内のファイルを反復処理することができます:and 'meta' not in fので、そのディレクトリに何があるかに応じて、オンとセーバーバージョンあなた:

import os 

dir_path = './' #change that to wherever your files are 
ckpt_files = [f for f in os.listdir(dir_path) if os.path.isfile(
    os.path.join(dir_path, f)) and 'ckpt' in f] 

for ckpt_file in ckpt_files: 
    saver.restore(sess, dir_path + ckpt_file) 
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 
    print('Succesfully loaded model from %s at step=%s.' % 
      (ckpt.model_checkpoint_path, global_step)) 

    # Do your thing 

のような、より選択的であることが上記のリスト内包でさらに条件を追加have

0

ありがとうございます。

CONV2 /バイアスが変数のスコープです:しかし、私はエラー

「キーCONV2 /偏見/ ExponentialMovingAverageチェックポイントでは見られないNotFoundErrorを(トレースバックについては上記参照)」を得ます。私は節約バージョンv2を使用します。

一方私は異なる(もう少し簡単なコード)を試みたと同じエラーを得た:

fileBaseName = FLAGS.checkpoint_dir + '/model.ckpt-' 

    for global_step in range(0,100,10): # range over the global steps where checkpoints were saved 
    x_str = str(global_step) 
    fileName = fileBaseName+x_str 
    print(fileName) 
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 

    #restore checkpoint file 
    saver.restore(sess, fileName) 

エラーが実際に(variables_to_restoreで=)コードのこの部分で発生:

# Restore the moving average version of the learned variables for eval. 
variable_averages = tf.train.ExponentialMovingAverage(
    MOVING_AVERAGE_DECAY) 
variables_to_restore = variable_averages.variables_to_restore() 
saver = tf.train.Saver(variables_to_restore) 

私はこのエラーを解決する方法を知りません。それはセーバー版と何か関係がありますか?または、チェックポイントが保存されている部分のエラーである必要がありますか?

ありがとうございます。 TheJude

関連する問題