2017-12-29 10 views
1

私の目的は単純明快です:グラフが一部変更された後、以前のログのチェックポイントファイルから変わらない変数/パラメータを復元する方法(より良いMonitoredTrainingSessionを使用して)テンソルフローでグラフが変更されたとき、 "MonitoredTrainingSession"を使用してチェックポイントの一部のみを復元する方法は?

私はここからのコードのテストを行います? https://github.com/tensorflow/models/tree/master/research/resnet resnet_model.pyで

、ライン116-118、オリジナルコード(またはグラフ)である:最初の訓練の後

with tf.variable_scope('logit'): 
    logits = self._fully_connected(x, self.hps.num_classes) 
    self.predictions = tf.nn.softmax(logits) 
with tf.variable_scope('costs'): 
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=self.labels) 
    self.cost = tf.reduce_mean(xent, name='xent') 
    self.cost += self._decay() 

、Iは、チェックポイントファイルを取得します。その後、私は最初のトレーニングで得られたチェックポイントを復元するためにlatested APIのtf.train.MonitoredTrainingSessionを使用しようと

with tf.variable_scope('logit_modified'): 
    logits_modified = self._fully_connected('fc_1',x, 48) 
    #self.predictions = tf.nn.softmax(logits)  
with tf.variable_scope('logit_2'): 
    logits_2 = self._fully_connected('fc_2', logits_modified, 
    self.hps.num_classes) 
    self.predictions = tf.nn.softmax(logits_2) 
with tf.variable_scope('costs'): 
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_2, labels=self.labels) 
    self.cost = tf.reduce_mean(xent, name='xent') 
    self.cost += self._decay() 

: は、それから私はにコードを修正しました。私はこれを行うために複数の方法を試しましたが、どれも動作しません。

1を試してみてください: 私はMonitoredTrainingSessionで足場を使用しない場合:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root, 
    #scaffold=scaffold, 
    hooks=[logging_hook, _LearningRateSetterHook()], 
    chief_only_hooks=[summary_hook], 
    save_checkpoint_secs = 600, 
    # Since we provide a SummarySaverHook, we need to disable default 
    # SummarySaverHook. To do that we set save_summaries_steps to 0. 
    save_summaries_steps=None, 
    save_summaries_secs=None, 
    config=tf.ConfigProto(allow_soft_placement=True), 
    stop_grace_period_secs=120, 
    log_step_count_steps=100) as mon_sess: 
while not mon_sess.should_stop(): 
    mon_sess.run(_train_op) 

エラーメッセージは次のとおりです。

2017-12-29 10:33:30.699061: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key logit_modified/fc_1/biases/Momentum not found in checkpoint ...

セッションtrysが変更に応じて復元すると思われますが新しいグラフと以前のチェックポイントファイルの両方に存在する変数(つまり、すべてのレイヤーが最終2を除外しています)。

2を試してみてください: ここtf.train.Supervisorを使用して転移学習コードに触発さ: https://github.com/kwotsin/transfer_learning_tutorial/blob/master/train_flowers.py、ラインから251

まず私は、この行を追加し、resnet_model.pyにコードを変更:

self.variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=["logit_modified", "logit_2"]) 

その後MonitoredTrainingSessionで足場がに変更されています。残念ながら

saver = tf.train.Saver(variables_to_restore) 
def restore_fn(sess): 
    return saver.restore(sess, FLAGS.log_root) 
scaffold = tf.train.Scaffold(saver=saver, init_fn = restore_fn) 

followiingエラーメッセージが示された:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: at 0x7f0ec26f4320>, error: Variables not initialized: logit_modified/fc_1/DW, ...

、残りの層が復元されないように、適切に復元されていない最後の2層のように思えます。

お試しください3: また、こちらの方法を試しましたが、How to use tf.train.MonitoredTrainingSession to restore only certain variablesのいずれも機能しません。

https://github.com/tensorflow/models/blob/6fb14a790c283a922119b19632e3f7b8e5c0a729/research/inception/inception/inception_model.pyにコードなどの復元方法があることは知っていますが、ネストされており、他のモデルに簡単に適用できるほど一般的ではありません。これが私が "MonitoredTrainingSession"を使いたい理由です。

"MonitoredTrainingSession"を使用してテンソルフローのチェックポイントの一部のみを復元するにはどうすればよいですか?

+0

ごめんことエラーメッセージは、常にStackOverflowにインデントの問題があります。代わりに、エラーメッセージの主なアイデアが質問に示されています。 – GhostPotato

答えて

0

最後に、私はそれを理解します。

monitored_sessionを読み取った後。ここでは https://github.com/tensorflow/tensorflow/blob/4806cb0646bd21f713722bd97c0d0262c575f7e0/tensorflow/python/training/monitored_session.py、私は新しい空のチェックポイントディレクトリに変更することが重要であることを発見しました。そのため、MonitoredTrainingSessionはinit_opまたはinit_fnを無視しません。 そして、あなたは(チェックポイントを回復するために)あなたのinit_fnを構築するために、次のコードを使用するだけでなく、足場ができます。

variables_to_restore = tf.contrib.framework.get_variables_to_restore(
    exclude=['XXX'])  
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    ckpt.model_checkpoint_path, variables_to_restore) 
def InitAssignFn(scaffold,sess): 
    sess.run(init_assign_op, init_feed_dict) 

scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn) 

は、上記ckpt.model_checkpoint_pathはそれでpretrainedファイルを使用して、古いチェックポイントのパスであることを忘れないでください。私は上記している新しい空のチェックポイントディレクトリはここMonitoredTrainingSessionのパラメータ「checkpoint_dir」を意味します

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root_2,...) as mon_sess: 
while not mon_sess.should_stop(): 
    mon_sess.run(_train_op) 

私が変更されたコードの最初の段落は、ライン134から、tf.slimにlearning.pyから派生: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py

プラス:インスピレーションのための おかげで、このQ & A、解決策は少し異なるが: What's the recommend way of restoring only parts model in distributed tensorflow

関連する問題