2017-04-13 17 views
1

私はTensorflowでLSTMニューラルネットワークを作成しています。Tensorflow LSTM - LSTMセルの行列乗算

入力テンソルのサイズは92

import tensorflow as tf 
from tensorflow.contrib import rnn 
import data 

test_x, train_x, test_y, train_y = data.get() 

# Parameters 
learning_rate = 0.001 
epochs = 100 
batch_size = 64 
display_step = 10 

# Network Parameters 
n_input = 28 # input size 
n_hidden = 128 # number of hidden layers 
n_classes = 20 # output size 

# Placeholders 
x = tf.placeholder(dtype=tf.float32, shape=[None, n_input]) 
y = tf.placeholder(dtype=tf.float32, shape=[None, n_classes]) 

# Network 
def LSTM(x): 
    W = tf.Variable(tf.random_normal([n_hidden, n_classes]), dtype=tf.float32) # weights 
    b = tf.Variable(tf.random_normal([n_classes]), dtype=tf.float32) # biases 

    x_shape = 92 

    x = tf.transpose(x) 
    x = tf.reshape(x, [-1, n_input]) 
    x = tf.split(x, x_shape) 

    lstm = rnn.BasicLSTMCell(
     num_units=n_hidden, 
     forget_bias=1.0 
    ) 
    outputs, states = rnn.static_rnn(
     cell=lstm, 
     inputs=x, 
     dtype=tf.float32 
    ) 

    output = tf.matmul(outputs[-1], W) + b 

    return output 

# Train Network 
def train(x): 
    prediction = LSTM(x) 

    with tf.Session() as sess: 
     sess.run(tf.global_variables_initializer()) 
     output = sess.run(prediction, feed_dict={"x": train_x}) 
     print(output) 

train(x) 

である私はすべてのエラーを取得していないんだけど、私はサイズ92の入力テンソルを供給しています、とLSTM関数の行列乗算は含むリストを返します。所望の量が92である場合の1つの結果ベクトル、1入力あたり1つの結果ベクトル。

は、私は出力配列内の最後の項目のみをマトリックス乗算てる問題ですか?

output = tf.matmul(outputs[-1], W) + b 

代わりに:このように、最も単純なリカレントニューラルネットを作るための

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul' (op: 'MatMul') with input shapes: [92,?,128], [128,20]. 

答えて

0

static_rnn:

output = tf.matmul(outputs, W) + b 

これは、私は後者を行うときに私が取得エラーです。 Here's the tf documentation。それへの入力はテンソルのシーケンスでなければなりません。たとえば、 "Hi"、 "How"、 "Are"、 "You"という4つの単語を入力したいとします。だからあなたの入力プレースホルダーは、各単語に対応する4つのn(各入力ベクトルのサイズ)の次元ベクトルで構成されるべきです。

あなたのプレースホルダに何か問題があると思います。 RNNへの入力数で初期化する必要があります。 28は各ベクトルの次元数です。私は92が配列の長さだと信じています。 (より多くの92のlstm細胞のような)

出力リストでは、隠されたユニットの数に等しいサイズのそれぞれのシーケンスの長さに等しいベクトルのセットを得るでしょう。

関連する問題