2017-02-03 8 views
1

Networkクラス内の以下のメソッドを使用して、事前に訓練されたネットワークをTensorflowにロードします(したがって、self.xyzへの呼び出し)。まず、define_network()が呼び出され、次に他の変数とオプティマイザの初期化が行われ、次にload_model()が呼び出されます。Tensorflowの特定のスコープへのグラフのロード

ただし、tf.variable_scope(self.name)を使用していても、グラフの変数は変数の汎用スペースにロードされます。これは、このクラスの2つのインスタンスがそれぞれ同じネットワークにロードされ、異なるスコープに分割するために問題があります。

どのように変数を特定のスコープにロードできますか?

P.S.コード内のエラーで私を修正してください!

def load_model(self): 
    with tf.variable_scope(self.name) as scope: 
     self.saver.restore(self.sess, self.model_path) 
     print("Loaded model from {}".format(self.model_path)) 

    def define_model(self): 

    with tf.variable_scope(self.name) as scope: 
     self.saver = tf.train.import_meta_graph(self.model_path + '.meta') 
     print("Loaded model from {}".format(self.model_path + '.meta')) 
     graph = tf.get_default_graph() 

     self.inputs = [] 
     inp_names = ['i_hand1:0', 'i_hand2:0', 'i_flop1:0', 'i_flop2:0', 'i_flop3:0', 
        'i_turn:0', 'i_river:0', 'i_other:0', 'i_allowed_mod:0', 'keras_learning_phase:0'] 
     for inp in inp_names: 
     self.inputs.append(tf.get_default_graph().get_tensor_by_name(inp)) 

     self.outputs = tf.get_default_graph().get_tensor_by_name("Tanh:0") 
     self.add_output_conversions() 

     all_vars = tf.trainable_variables() 
     for var in all_vars: 
     self.var[var.name] = var 

答えて

0

私はあなたの問題はここで私はそれを試してみたが、私は動作しませんでしたreference

+0

self.saver = tf.train.import_meta_graph(self.model_path + '.meta', 'import_scope'=self.name) 

に引数を追加することによって解決することができると思います。エラーで失敗しました。私は、この議論は、新しいスコープではなく、グラフがどのように保存されたかに関連していると思います。 – pir

+0

どのようなエラーが発生しましたか? –

関連する問題