2017-05-03 10 views
0

私は今このペットの問題に苦しんでいますので、助けてください!LSTM Tensorflowモデルではシーケンスを考慮していません

私は、いくつかのランダムな列を持つcsvファイルと、最初の列の最後の数値の合計に基づく最終列を持っています。私はLSTMモデルを使用してこの構造を捕捉しようとしています。つまり、最初の数桁の最後の列を予測します。

は、ここで私が使用してきたモデルだ:私は、ランダムな数字といくつかのスプレッドシートを試してみた

# Generate test data 

train_input = train_input.reshape(m, n_input, 1) # is nr of rows, n_input is number of input columns 

NUM_EXAMPLES = int(m * training_size) 

test_input = train_input[NUM_EXAMPLES:] 
test_output = train_output[NUM_EXAMPLES:] 

train_input = train_input[:NUM_EXAMPLES] 
train_output = train_output[:NUM_EXAMPLES] 
# 
# # Design model 
# 
data = tf.placeholder(tf.float32, [None, n_input, 1]) 
target = tf.placeholder(tf.float32, [None, n_classes]) 

num_hidden = 24 
cell = tf.contrib.rnn.LSTMCell(num_hidden, state_is_tuple=True) 

val, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32) 

val = tf.transpose(val, [1, 0, 2]) 
last = tf.gather(val, int(val.get_shape()[0]) - 1) 

weight = tf.Variable(tf.truncated_normal([num_hidden, int(target.get_shape()[1])])) 
bias = tf.Variable(tf.constant(0.1, shape=[target.get_shape()[1]])) 

prediction = tf.nn.softmax(tf.matmul(last, weight) + bias) 

cross_entropy = -tf.reduce_sum(target * tf.log(tf.clip_by_value(prediction,1e-10,1.0))) 

optimizer = tf.train.AdamOptimizer() 
minimize = optimizer.minimize(cross_entropy) 

mistakes = tf.not_equal(tf.argmax(target, 1), tf.argmax(prediction, 1)) 
error = tf.reduce_mean(tf.cast(mistakes, tf.float32)) 

init_op = tf.global_variables_initializer() 
sess = tf.Session() 
sess.run(init_op) 

no_of_batches = int(len(train_input)/batch_size) 
for i in range(epoch): 
    ptr = 0 
    for j in range(no_of_batches): 
     inp, out = train_input[ptr:ptr+batch_size], train_output[ptr:ptr+batch_size] 
     ptr+=batch_size 
     sess.run(minimize,{data: inp, target: out}) 
    print("Epoch - {}".format(i)) 
incorrect = sess.run(error,{data: test_input, target: test_output}) 
print('Epoch {:2d} error {:3.1f}%'.format(i + 1, 100 * incorrect)) 
sess.close() 

、と私は一貫して周りの83%の誤り率を取得しています。一方、このアルゴリズムは、ターゲット列が連続していないかどうかを知ることができます。

ありがとうございます!

答えて

0

私は明確にあなたのポイントを得ることができません、あなたはこのようなcsvファイルを持っているのですか? sum(vn1+...+vnn)に基づい

x1 x2 x3 x4 ... xn 
v11 v21 v31 v41 ... vn1 
v12 v22 v32 v42 ... vn2 
... 
v1n v2n v3n v4n ... vnn 
y1 y2 y3 y4 ... yn 

yna * sum(V) + bのように?

+0

遅くとも申し訳ありませんが、ynはx1列の最後の5つのエントリのみに基づいています。 yn = x1(n-5)+ x1(n-4)... + x1nの場合、ynの最初の4つのエントリは0です。 – Akubara

関連する問題