2016-11-28 12 views
1

私は、Tensorflowがグラフを管理しているように見える方法で、いくつかのファンキーなことが起こっていることを認識しました。Tensorflowはどのようにグラフを管理しますか?

モデルのビルド(および再構築)が面倒なので、カスタムモデルをクラスにラップして他の場所で簡単にインスタンス化できるようにしました。

私は訓練をしていて、元の場所でコードをテストしていましたが、グラフの変数をロードしたコードでは変わった再定義やその他すべての奇妙なエラーが発生します。これは(似たようなことについての私の最後の質問から)、すべてが2回呼び出されたというヒントでした。

トレースのトーンを実行した後、ロードされたコードを使用していた方法になりました。それはそう

class MyModelUser(object): 
    def forecast(self): 
     # .. build the model in the same way as in the training code 
     # load the model checkpoint 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

のような構造を持っていたし、いくつかのコードでは、私は

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

を持っていたし、(明らかに)私は2つの予測これを見ると期待MyModelUser使用して、そのクラスの中から使用されていましたと呼ばれていました。代わりに、最初の予測が呼ばれ、期待通りに働いたが、2回目の呼び出しは、変数の再利用のTONとValueErrorこれらのいずれかの例を投げたのだった。

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope? 

私はシリーズを追加することで、エラーを鎮めるために管理変数を作成するためにget_variableを使用し、スコープでreuse_variablesと呼ばれ、例外としてget_variableと名前を付けずにtry/exceptブロックを削除しました。私が言った気まぐれで

tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files 

:これはだったそのうちの一つの厄介なエラーの新しいセット、上もたらした「私は__init__にモデリング建物のコードを移動した場合ので、その一度だけ組み込まれて何?」

私の新しいモデルのユーザー:今

class MyModelUser(object): 
    def __init__(self): 
     # ... build the model in the same way as in the training code 
     # load the model checkpoint 


    def forecast(self): 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

と:期待など

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

作品、エラーなしで2人の予想を印刷します。これは、私が変数を再利用するものを取り除くこともできると私につながります。

私の質問は次のとおりです。

なぜこれを修正しましたか?理論的には、元の予測方法で毎回グラフを再インスタンス化する必要があるため、複数のグラフを作成してはいけません。関数が完了した後もTensorflowはグラフを保持しますか?これで、作成コードを__init__に移動するのはなぜですか?これは私を絶望的に混乱させてしまった。

答えて

2

デフォルトでは、TensorFlowは、最初にTensorFlow APIを呼び出すときに作成される単一のグローバルtf.Graphインスタンスを使用します。tf.Graphを明示的に作成しないと、すべての操作、テンソル、および変数がデフォルトインスタンスで作成されます。つまり、コード内の各呼び出しがmodel_user.forecast()になると、同じグローバルグラフに操作が追加されますが、これはやや無駄です。

があり、ここでアクションの2つの可能性のあるコース(少なくとも):

  • 理想的なアクションはMyModelUser.__init__()が予測を実行するために必要なすべての操作で全体tf.Graphを構築するようにコードを再構築することです、 MyModelUser.forecast()は、単に既存のグラフ上でsess.run()コールを実行します。理想的には、単一のtf.Sessionも作成します。これは、TensorFlowがセッションに関するグラフに関する情報をキャッシュし、実行がより効率的になるためです。

  • 侵襲性の低い—おそらくあまり効率的—変更はMyModelUser.forecast()へのすべての呼び出しのための新しいtf.Graphを作成することです。それは非常に状態がMyModelUser.__init__()方法で作成する方法を質問からは不明だが、あなたは別のグラフに2つのコールを置くために、次のような何かができる:

    def test_the_model(self): 
        with tf.Graph(): # Create a local graph 
        model_user_1 = MyModelUser() 
        print(model_user_1.forecast()) 
        with tf.Graph(): # Create another local graph 
        model_user_2 = MyModelUser() 
        print(model_user_2.forecast()) 
    
0

TFには、新しい操作などが追加されるデフォルトのグラフがあります。関数を2回呼び出すと、同じグラフを同じグラフに2回追加します。だから、グラフを一度作成して、それを何度も評価してください(これは「通常の」アプローチでもあります)。変更したい場合は、reset_default_graph https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graphを使ってグラフをリセットしてください新鮮な状態。

関連する問題