2016-04-28 7 views
2

私は以下のような状況に陥っています:TensorFlow:変数の内容に名前でアクセス

私はすでにネットを構築し、保存しています。今、私はネットを復元し、重み行列を視覚化しようとしています。

私はすべての変数の名前を知っていますが、評価のためにセッションに渡す変数に割り当てられたPythonマーカーはありません。変数内のデータを取得するにはどうすればよいですか?ここで

は、私のコードの状況です:

dataset_params = nn_params.mnist_dataset_params 
design = nn_designs.mnist_net_A_design 
## Build Housing Object 
mnist_nn = nn_class.CNN(**dataset_params) 
mnist_nn.build_net(design['design']) 
mnist_nn.__setattr__('saved_path',saved_model) 
mnist_nn_epoch_file = saved_model+'_epochs_completed.txt' 
mnist_nn.__setattr__('epoch_file',mnist_nn_epoch_file) 


# evaluate weight variables 
session = tf.Session() 
saver = tf.train.Saver() 
session.run(tf.initialize_all_variables()) 
saver.restore(session,saved_model) 




session.close() 

私はウェイトを引き出すためにセッションに渡す必要がありますか? (重量名の例: 'conv_w_1')?

答えて

6

あなたはこれが必要な変数を取得するためにtf.get_collection() lookupメソッドを使用している行うことができます。

weight_var = tf.get_collection(tf.GraphKeys.VARIABLES, "conv_w_1")[0] 

weight_var_value = session.run(weight_var) 
0

をそれとも、(関数tf.get_default_graphを使用して結果を得ることができます)get_tensor_by_name:

valua_of_conv_w_1 = session.run(tf.get_default_graph().get_tensor_by_name("conv_w_1:0")) 
関連する問題