2017-01-10 5 views
0

私は2つのグラフを保存しています。 1つは2X2のゼロテンソルを持ち、もう1つは同じサイズのテンソルを持つものです。私は条件に応じてそれらの1つを復元します。tf.Session()はwithブロックを終了しても一見閉じられません

import tensorflow as tf 


def save_zero(): 
    # save a 2x2 variable filled with zeros 
    with tf.Graph().as_default(): 
     session = tf.Session() 
     with tf.name_scope('dummy_graph'): 
      tf.Variable([[0.0, 0.0], [0.0, 0.0]], name='a', dtype=tf.float32) 
     init_op = tf.global_variables_initializer() 
     session.run(init_op) 
     saver = tf.train.Saver() 
     saver.save(session, 'zero') 
     session.close() 


def save_one(): 
    # save a 2x2 variable filled with ones 
    with tf.Graph().as_default(): 
     session = tf.Session() 
     with tf.name_scope('dummy_graph'): 
      tf.Variable([[1.0, 1.0], [1.0, 1.0]], name='a', dtype=tf.float32) 
     init_op = tf.global_variables_initializer() 
     session.run(init_op) 
     saver = tf.train.Saver() 
     saver.save(session, 'one') 
     session.close() 


def test(boolean): 
    with tf.Session() as session: 
     if boolean: 
      saver = tf.train.import_meta_graph('one.meta') 
      saver.restore(session, './one') 

      session.run(session.graph.get_operation_by_name('init')) 
      tensor = session.graph.get_tensor_by_name('dummy_graph/a:0') 
     else: 

      saver = tf.train.import_meta_graph('zero.meta') 
      saver.restore(session, './zero') 

      session.run(session.graph.get_operation_by_name('init')) 
      tensor = session.graph.get_tensor_by_name('dummy_graph/a:0') 

     return session.run(tensor) 

save_zero() 
save_one() 
print(test(False)) 
print(test(True)) 

testの呼び出しは両方ともゼロを返します。

def test(boolean): 
    with tf.Session() as session: 
     if boolean: 
      saver = tf.train.import_meta_graph('one.meta') 
      saver.restore(session, './one') 
      # contains duplicate ops (suffixed with '_1') 
      [print(op.name) for op in session.graph.get_operations()] 
      session.run(session.graph.get_operation_by_name('init')) 
      tensor = session.graph.get_tensor_by_name('dummy_graph/a:0') 
     else: 

      saver = tf.train.import_meta_graph('zero.meta') 
      saver.restore(session, './zero') 

      session.run(session.graph.get_operation_by_name('init')) 
      tensor = session.graph.get_tensor_by_name('dummy_graph/a:0') 

     return session.run(tensor) 

このバグであるか私が行方不明です:セッションでOPSを観察するtest戻り時にtestでセッションが閉じられている私の知る限りは、セッションとして起こることを想定されていないの二つの呼び出しで再利用されていることを示唆しています何か?

+0

問題があるかどうかはわかりませんが、セッションを閉じることなく両方の保存機能でセッションを開きます。 save_one()とsave_zero()の両方の終わりにsession.close()を追加しても問題が解決するかどうか確認できますか? –

+0

@MadWombatそれは問題ではありません。とにかくありがとう。 – Priyatham

+1

どのバージョン?私は2つの 'test'コールとは異なる結果を得ます。0.12 –

答えて

2

TL; DR:tf.Sessionは、あなたのコード内でtest()に2回の呼び出しの間閉鎖されているが、2つのセッションが同じtf.Graphを共有しているので、あなたが問題に実行されています。問題を回避するには、新鮮なtf.Graphで各セッションを作成します。特に

その後test(True)への呼び出しでtf.train.import_meta_graph()を呼び出すとき、あなたはtest(False)の呼び出しでtf.train.import_meta_graph()を呼び出したときに作成したノードがグラフに残ります。つまり、session.graph.get_tensor_by_name('dummy_graph/a:0')への2回の呼び出しのそれぞれは、最初にtest()を呼び出したときに作成された同じノードを返します。

この問題を回避する方法はいくつかあります。最も単純には独自のグラフでtf.Sessionを作成することです:あなたは新しいグラフを毎回作成している、チェックポイントセーブ中に

、そう:mrry答えに詳細を追加するには

def test(boolean): 
    # Session will use its own graph. 
    with tf.Session(graph=tf.Graph()) as session: 
     if boolean: 
      # ... 
+0

ありがとう、それは私の理解に多くの助けとなりました。私は、各セッションインスタンスにデフォルトグラフがあるという印象を受けていましたが、プロセスごとにグローバルデフォルトグラフが1つしかないことが判明しました。あれは正しいですか? – Priyatham

+0

ほぼはい。実際に**スレッド**ごとに1つのグローバルデフォルトグラフが存在します。 'with graph.as_default():'ブロックを入力すると、現在のスレッドのデフォルトグラフとして 'graph'が設定され、ブロックを終了すると、以前のデフォルトにリセットされます。 'tf.Session'を作成すると、イニシャライザでオプションの' graph'引数を指定しないかぎり、グラフとして現在のデフォルトのグラフが使用されます。 – mrry

+0

[api docs](https://www.tensorflow.org/api_docs/python/framework/core_graph_data_structures#Graph.as_default)から引用すると、 'このメソッドは、同じプロセスで複数のグラフを作成する場合に使用する必要があります。 '恐らく '...同じ**スレッド** 'はあまり混乱しないでしょう、あなたはそれを修正できますか? – Priyatham

1

、ここで起こっている正確に何です両方に保存されていることテンソルdummy_graph/aチェックポイント:最初のロード中に

zero.data-00000-of-00001: dummy_graph/a - [0, 0, 
one.data-00000-of-00001: dummy_graph/a - [1, 1, 

を、次にその初期化オペアンプを呼び出し、それに負荷[0, 0,、あなたが最初の変数dummy_graph/aでグラフを作成して呼び出しますこの値を[0, 0,で上書きします。

2回目のロードコールでは、import_meta_graphが既存のデフォルトグラフに追加されます。名前の競合がありますので、それが今、あなたのグラフは第二のリストア時には、ノードdummy_graph/adummy_graph/a_1と、対応する初期化ノードinitinit_1

を持つことになり、ノードに_1を追加します、あなたのセーバーがdummy_graph/aにあなたの[1, 1, ...チェックポイントを復元します。次に、initと呼び出して、dummy_graph/aの値を[0, 0, ...に上書きします。 dummy_graph/a

2回目の復元後にセッションに2つの変数があり、2つ目の変数が初期化されていないことに注意してください。奇妙なことに、tf.report_uninitialized_variables()sess.run('dummy_graph/a_1:0')Attempting to use uninitialized valueというエラーをスローしても、それはバグのようです。

+0

2番目のものが初期化されないのはなぜですか?直感的には重複( '_1')変数が追加されますが、新しい値では復元されません。代わりに古い変数が上書きされて新しい変数は初期化されません – Priyatham

+1

2番目の変数は' init_1'で初期化されますが、 'init'を呼び出す –

+0

変数が' dummy/a'と名前をつけているとsaverが考えるので、新しい値でリストアされません。 'import_meta_graph'の間に起こった' a_1'への名前変更を認識していません。 IMHO '_1'を静かに追加する代わりにエラーがあった場合、混乱は少なくなりますが、' _1'の追加が一部の低レベルAPIであるため実装が難しいかもしれません –

関連する問題