2016-10-30 13 views
4

テンソルフローで次のバニラRNNの実装があります。 basicRNNCellからウェイトとバイアスの値を取得するにはどうすればよいですか?RNNセルからの重みと偏りの値を取得

import tensorflow as tf 
import numpy as np 

input_size = 5 
batch_size = 2 
max_length = 1 

cell = tf.nn.rnn_cell.BasicRNNCell(num_units = 4) 

# Batch size x time steps x features. 
data = tf.placeholder(tf.float32, [None, max_length, input_size]) 
output, _ = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    result = sess.run([output], feed_dict={data: np.ones((batch_size, max_length, input_size))}) 

    print result 
    print result[0].shape 

    for v in tf.trainable_variables(): 
     print v.name 
     print dir(v) 

答えて

10

名前でセッションから任意のテンソルの値を抽出することができます。

variables_names =[v.name for v in tf.trainable_variables()] 
values = sess.run(variables_names) 
for k,v in zip(variables_names, values): 
    print(k, v) 

戻り値この:

RNN/BasicRNNCell/Linear/Matrix:0 
RNN/BasicRNNCell/Linear/Bias:0 
(u'RNN/BasicRNNCell/Linear/Matrix:0', array([[ 0.0612123 , -0.3020778 , 0.39463997, 0.09347564], 
     [ 0.45926428, 0.23726827, -0.4563897 , -0.23666686], 
     [-0.45560977, -0.13659951, -0.51252407, 0.54929543], 
     [-0.54475051, -0.20766461, 0.01690435, -0.11470184], 
     [ 0.31095517, -0.5281173 , 0.50487423, -0.12220767], 
     [ 0.09355438, 0.14729732, -0.31751576, 0.39974809], 
     [ 0.30579591, -0.46520707, -0.48943958, 0.22013563], 
     [-0.08513373, 0.30004191, 0.06920779, 0.38332987], 
     [-0.36613646, -0.26537177, -0.18271935, 0.4455297 ]], dtype=float32)) 
(u'RNN/BasicRNNCell/Linear/Bias:0', array([ 0., 0., 0., 0.], dtype=float32)) 
関連する問題