2017-01-12 6 views
3

私は出力を生成するModelと出力をグレードするAdversaryの2つのネットワークを持っています。同じTensorflowセッションでセーバーから2つのモデルをロード

両方とも個別に訓練されていますが、今は1回のセッションで出力を結合する必要があります。

私はこの記事で提案されたソリューションを実装しようとしました:Run multiple pre-trained Tensorflow nets at the same time

私のコード

with tf.name_scope("model"): 
    model = Model(args) 
with tf.name_scope("adv"): 
    adversary = Adversary(adv_args) 

#... 

with tf.Session() as sess: 
    tf.global_variables_initializer().run() 

    # Get the variables specific to the `Model` 
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint 
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
        for v in tf.global_variables() if v.name[:5] == "model"} 
    model_saver = tf.train.Saver(var_list=model_varlist) 
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir) 
    model_saver.restore(sess, model_ckpt.model_checkpoint_path) 

    # Get the variables specific to the `Adversary` 
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
        for v in tf.global_variables() if v.name[:3] == "adv"} 
    adv_saver = tf.train.Saver(var_list=adv_varlist) 
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir) 
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path) 

機能model_saver.restore()への呼び出しがあるように思われる

問題何もしない。別のモジュールでは、tf.train.Saver(tf.global_variables())のセーバーを使用し、チェックポイントを正常に復元します。

モデルはmodel.tvars = tf.trainable_variables()です。何が起きているのかを確認するために、sess.run()を使用してリストアの前後にtvarsを抽出しました。初期にランダムに割り当てられた変数が使用され、チェックポイントからの変数が割り当てられないたびに

なぜmodel_saver.restore()に何もしていないようだと思っていますか?

答えて

13

この問題を解決するには時間がかかりましたので、他の誰かが必要とする可能性がある場合に備えて、不完全なソリューションを投稿しています。

問題を診断するために、私は手動で各変数をループし、1つずつ割り当てました。それから、変数を割り当てた後に名前が変わることに気付きました。これはここに記述されています:TensorFlow checkpoint save and read

私はそれぞれのモデルをそれぞれのグラフで実行しました。それはまた、私自身のセッションで各グラフを実行しなければならないことを意味します。これは、セッション管理を別の方法で処理することを意味しました。

まず私はいつでも、それぞれ次に、2つのセッション

adv_sess = tf.Session(graph=adv_graph) 
sess = tf.Session(graph=model_graph) 

は、その後、私はここから

別途
with sess.as_default(): 
    with model_graph.as_default(): 
     tf.global_variables_initializer().run() 
     model_saver = tf.train.Saver(tf.global_variables()) 
     model_ckpt = tf.train.get_checkpoint_state(args.save_dir) 
     model_saver.restore(sess, model_ckpt.model_checkpoint_path) 

with adv_sess.as_default(): 
    with adv_graph.as_default(): 
     tf.global_variables_initializer().run() 
     adv_saver = tf.train.Saver(tf.global_variables()) 
     adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir) 
     adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path) 

を各セッションで変数を初期化し、各グラフを復元二つのグラフ

model_graph = tf.Graph() 
with model_graph.as_default(): 
    model = Model(args) 

adv_graph = tf.Graph() 
with adv_graph.as_default(): 
    adversary = Adversary(adv_args) 

を作成しましたセッションが必要でした私はすべてのtfの関数をラップします控え目はwith sess.as_default():です。最後にセッションを手動で終了します

sess.close() 
adv_sess.close() 
+0

2番目のモデルへの入力として最初のモデルの出力を使用できますか? – Sentient07