2016-12-20 16 views
1

tf.get_collection()を使用するとRNNセルが表示されません。私は何が欠けていますか?tf.get_collectionにRNNセルがありません

import tensorflow as tf 
print(tf.__version__) 

rnn_cell = tf.nn.rnn_cell.LSTMCell(16) 
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) 

other_var = tf.Variable(0) 
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) 

プリントが出

0.12.0 
[] 
[<tensorflow.python.ops.variables.Variable object at 0x0000027961250B70>] 

のWindows 10、Pythonの3.5

+1

(私はbatch_size=10rnn_size=16を仮定している)の代わりにこれを試してみてください – martianwars

答えて

2

あなたは、あなたの変数が表示されない理由であるLSTMCell__call__を実行していません。あなたは、その変数を作成するために `` LSTMCell` __call__`にする必要があります

import tensorflow as tf 
print(tf.__version__) 

rnn_cell = tf.nn.rnn_cell.LSTMCell(16) 
a = tf.placeholder(tf.float32, [10, 16]) 
zero = rnn_cell.zero_state(10,tf.float32) 
# The variables are created in the following __call__ 
b = rnn_cell(a, zero) 
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) 

other_var = tf.Variable(0) 
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)) 
関連する問題