2017-03-02 10 views
0

私は画像分類のためのCNNを構築しました。トレーニング中にいくつかのチェックポイントを保存しました。データはfeed_dictionaryを通じてネットワークに送られます。グラフの復元中にTensorflowがfeed_dictが足りないと文句を言う

今、私は失敗したモデルを復元したいのですが、その理由を理解できません。次のようにコードの重要な行は、次のとおりです。

トレースバック(最新の呼び出しの最後):: ファイル「C:\プログラムファイル\ Anaconda3 \

with tf.Graph().as_default(): 

.... 

if checkpoint_dir is not None: 
    checkpoint_saver = tf.train.Saver() 
    session_hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir, 
                 save_secs=flags.save_interval_secs, 
                 saver=checkpoint_saver)) 
.... 

with tf.train.MonitoredTrainingSession(
     save_summaries_steps=flags.save_summaries_steps, 
     hooks=session_hooks, 
     config=tf.ConfigProto(
      log_device_placement=flags.log_device_placement)) as mon_sess: 

    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) 
    if checkpoint and checkpoint.model_checkpoint_path: 

     # restoring from the checkpoint file 
     checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path) 

     global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1] 
     print("Model restored from checkpoint: global_step = %s" % global_step_restore) 

ライン "checkpoint_saver.restoreは、" エラーがスローされます_do_callの1022行目、 戻り値fn(* args) ファイル "C:¥Program Files¥Anaconda3¥envs¥tensorflow¥lib¥envs¥tensorflow¥lib¥site-packages¥tensorflow¥python¥client¥session.py"行1004、_run_fn ステータス、run_metadata) ファイル "C:¥Program Files¥Anaconda3¥envs¥tensorflow¥lib¥contextlib.py"、6行目、サイトパッケージ\テンソルフロー\ python \ client \ session.py " 6、in exit 次(self.gen) raise_exception_on_not_ok_statusでファイル "C:¥Program Files¥Anaconda3¥envs¥tensorflow¥lib¥site-packages¥tensorflow¥python¥framework¥errors_impl.py" 469行目 pywrap_tensorflow.TF_GetCode(ステータス)) tensorflow.python.framework.errors_impl.InvalidArgumentError:input_images = Placeholderdtype = DT_FLOAT、形状= []、_device:あなたはDTYPEフロート [ノードとプレースホルダテンソル 'input_images' の値を供給しなければなりません= "/ job:localhost/replica:0/task:0/cpu:0"]]

これを解決する方法はありますか?なぜ私はグラフを復元するために充填されたfeed_dictionaryが必要なのですか?

ありがとうございます!

アップデート:私は得ることはありません何

def restore(self, sess, save_path): 
    """Restores previously saved variables. 

    This method runs the ops added by the constructor for restoring variables. 
    It requires a session in which the graph was launched. The variables to 
    restore do not have to have been initialized, as restoring is itself a way 
    to initialize variables. 

    The `save_path` argument is typically a value previously returned from a 
    `save()` call, or a call to `latest_checkpoint()`. 

    Args: 
     sess: A `Session` to use to restore the parameters. 
     save_path: Path where parameters were previously saved. 
    """ 
    if self._is_empty: 
     return 
    sess.run(self.saver_def.restore_op_name, 
      {self.saver_def.filename_tensor_name: save_path}) 

:なぜグラフがすぐに実行され

これは、セーバーオブジェクトの復元方法のコードですか?間違った方法を使っていますか?私はすべての訓練可能な瓶を元に戻したい。

+0

すべての変数とプレースホルダに名前を付けます。これは役に立ちましたか? http://stackoverflow.com/questions/34793978/tensorflow-complaining-about-placeholder-after-model-restore – hars

+0

すべてのvarsの名前が付けられます。画像テンソルの入力フィードがありません。私は、この問題は、MonitoredTrainingSessionとfeed_dictを組み合わせて使用​​したことが原因だと考えています。 MonitoredTrainingSessionは、より大きな設定に使用されることを意図しており、フィード辞書と互換性がない可能性があります。カスタム "トレーニングフレームワーク"のテストケースを構築しようとしています。そのため、私はモデルの軽量化を図り続けたい(インポートキューではなく、feed_dictを使用する) – monchi

答えて

1

問題は、プロセスログにSessionRunHookによって原因だった:

オリジナルフック:

class _LoggerHook(tf.train.SessionRunHook): 
    """Logs loss and runtime.""" 

    def begin(self): 
    self._step = -1 

    def before_run(self, run_context): 
    self._step += 1 
    self._start_time = time.time() 
    return tf.train.SessionRunArgs(loss) # Asks for loss value. 

    def after_run(self, run_context, run_values): 
    duration = time.time() - self._start_time 
    loss_value = run_values.results 
    if self._step % 5 == 0: 
     num_examples_per_step = FLAGS.batch_size 
     examples_per_sec = num_examples_per_step/duration 
     sec_per_batch = float(duration) 

     format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
        'sec/batch)') 
     print (format_str % (datetime.now(), self._step, loss_value, 
          examples_per_sec, sec_per_batch)) 

修正フック:

class _LoggerHook(tf.train.SessionRunHook): 
    """Logs loss and runtime.""" 

    def __init__(self, flags, loss_op): 
     self._flags = flags 
     self._loss_op = loss_op 
     self._start_time = time.time() 

    def begin(self): 
     self._step = 0 

    def before_run(self, run_context): 
     if self._step == 0: 
      run_args = None 
     else: 
      run_args = tf.train.SessionRunArgs(self._loss_op) 

     return run_args 

    def after_run(self, run_context, run_values): 

     if self._step > 0: 
      duration_n_steps = time.time() - self._start_time 
      loss_value = run_values.results 
      if self._step % self._flags.log_every_n_steps == 0: 
       num_examples_per_step = self._flags.batch_size 

       duration = duration_n_steps/self._flags.log_every_n_steps 
       examples_per_sec = num_examples_per_step/duration 
       sec_per_batch = float(duration) 

       format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
           'sec/batch)') 
       print(format_str % (datetime.now(), self._step, loss_value, 
            examples_per_sec, sec_per_batch)) 

       self._start_time = time.time() 
     self._step += 1 

説明:

ロギングが今skipedです最初の反復のために。したがって、Saver.restore(..)によって実行されるsession.runは、埋められたフィード辞書をもうrequiereしません。

関連する問題