私は出力を生成するModel
と出力をグレードするAdversary
の2つのネットワークを持っています。同じTensorflowセッションでセーバーから2つのモデルをロード
両方とも個別に訓練されていますが、今は1回のセッションで出力を結合する必要があります。
私はこの記事で提案されたソリューションを実装しようとしました:Run multiple pre-trained Tensorflow nets at the same time
私のコード
with tf.name_scope("model"):
model = Model(args)
with tf.name_scope("adv"):
adversary = Adversary(adv_args)
#...
with tf.Session() as sess:
tf.global_variables_initializer().run()
# Get the variables specific to the `Model`
# Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
model_varlist = {v.name.lstrip("model/")[:-2]: v
for v in tf.global_variables() if v.name[:5] == "model"}
model_saver = tf.train.Saver(var_list=model_varlist)
model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)
# Get the variables specific to the `Adversary`
adv_varlist = {v.name.lstrip("avd/")[:-2]: v
for v in tf.global_variables() if v.name[:3] == "adv"}
adv_saver = tf.train.Saver(var_list=adv_varlist)
adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)
機能model_saver.restore()
への呼び出しがあるように思われる
問題何もしない。別のモジュールでは、tf.train.Saver(tf.global_variables())
のセーバーを使用し、チェックポイントを正常に復元します。
モデルはmodel.tvars = tf.trainable_variables()
です。何が起きているのかを確認するために、sess.run()
を使用してリストアの前後にtvars
を抽出しました。初期にランダムに割り当てられた変数が使用され、チェックポイントからの変数が割り当てられないたびに
なぜmodel_saver.restore()
に何もしていないようだと思っていますか?
2番目のモデルへの入力として最初のモデルの出力を使用できますか? – Sentient07