2017-10-05 10 views
2

テンソルフローを使用して、異なる変数スコープの同じネットワークから同一の重みを共有しようとしています。しかし、これを行う簡単な方法はないようです。私は、より大きなサブネットワークをどうしたいどのような小規模で説明するために、小さなコードサンプルを用意していますテンソルフローの異なるスコープにわたるサブネットワーク重みの共有

import tensorflow as tf 

graph = tf.Graph() 
with graph.as_default(): 
    with tf.variable_scope("super_scope_one"): 
     scope1 = tf.variable_scope("sub_scope_one") 
     with scope1: 
      number_one = tf.get_variable("number_one", shape=[1], 
             initializer=tf.ones_initializer) 
    with tf.variable_scope("super_scope_two"): 
     with tf.variable_scope("sub_scope_one", reuse=True) as scope2: 
      # Here is the problem. 
      # scope1.reuse_variables() # this crashes too if reuse=None. 
      number_one = tf.get_variable("number_one", shape=[1]) 
     with tf.variable_scope("sub_scope_two"): 
      number_two = tf.get_variable("number_two", shape=[1], 
             initializer=tf.ones_initializer) 
     number_three = number_one + number_two 

    init_op = tf.global_variables_initializer() 

with tf.Session(graph=graph): 
    init_op.run() 
    print(number_three.eval()) 

は スコープを削除せずに、二つのサブスコープで変数を共有する方法はあります上記?そして、それがなぜ悪いアイディアになるのかという正当な理由はありませんか?

+0

それが代わりにやっている何を期待している行動があると[2]? – Engineero

+0

私は現在、上記の例に似たネストされたスコープでtf.get_variableを呼び出すいくつかの同一のpretrained tf.slimサブネットワークを使用しています。現時点では、私はウェイトを共有しておらず、すべて同じネットワークを初期化しています。私はメモリ消費量を減らしたいですが、どのように知っているのですか? – v0lta

答えて

1

number_one"super_scope_one"に一度だけ定義して"super_scope_two"でも使用できます。

異なるスコープの2つの変数を一緒に使用できます。以下を参照してください:

import tensorflow as tf 

graph = tf.Graph() 
with graph.as_default(): 
    with tf.variable_scope("super_scope_one"): 
     scope1 = tf.variable_scope("sub_scope_one") 
     with scope1: 
      number_one = tf.get_variable("number_one", shape=[1], 
             initializer=tf.ones_initializer) 
    with tf.variable_scope("super_scope_two"): 
     with tf.variable_scope("sub_scope_two"): 
      number_two = tf.get_variable("number_two", shape=[1], 
             initializer=tf.ones_initializer) 
     number_three = number_one + number_two 

    init_op = tf.global_variables_initializer() 

    with tf.Session(graph=graph): 
     init_op.run() 
     print(number_three.eval()) 

戻り

+0

あなたの返信ありがとう!あなたは当然ですが、get_variableをそのまま残しながらこれを行う方法についても考えてください。私は "super_scope_two"の下の "sub_scope_one"でscope1.reuse_variables()を呼び出そうとしましたが、クラッシュしました。それは私に多くの仕事を節約するでしょう。 – v0lta

関連する問題