2017-08-26 24 views
0

私は自分のネットワークをImageNetで訓練していますので、訓練されたウエイトのサブセットをプロジェクトに使用できます。テンソルフローの変数のサブセットを保存/復元できません

ホールウェイトの保存と復元は問題ありませんが、完全に接続されたレイヤーなしで保存しようとすると、エラーが発生します。NameError:グローバル名 'w1'は定義されていません。それは誰も助けている場合、リポジトリはgithubであるかのコードスニペット:

inference.py

... 
def inference(images): 
    w1 = tf.get_variable('w1', shape=[5,5,3,64]) 
    ... 

grasp.py

def run_training(): 
    ... 
    logits = inference(images) 
    ... 
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) 
sess = tf.Session() 
sess.run(init_op) 
saver = tf.train.Saver({'w1': w1}) 

トレースバック

Traceback (most recent call last): 
    File "./grasp.py", line 130, in <module> 
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 
    File "/usr/local/lib/python2.7/site- 
    packages/tensorflow/python/platform/app.py", line 44, in run 
    _sys.exit(main(_sys.argv[:1] + flags_passthrough)) 
    File "./grasp.py", line 83, in main 
    run_training() 
    File "./grasp.py", line 52, in run_training 
    saver = tf.train.Saver({'w1': w1}) 
NameError: global name 'w1' is not defined 

ご意見やご要望がございましたら、お知らせください。

答えて

0

tf.global_variables()の変数にアクセスする必要があります。

w1 = [v for v in tf.global_variables() if v.name == 'w1:0'][0] 
saver = tf.train.Saver({'w1': w1}) 
+0

Iは[v.name == 'W1:0' の場合tf.get_collectionでVは、V(tf.GraphKeys.GLOBAL_VARIABLES)] 'W1 =もこれを発見した' [0]。 – prometeu

関連する問題