2017-01-21 27 views
0
rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size) 
state = rnn_cell.zero_state(batch_size, tf.float32) 
init = tf.global_variables_initializer() 
sess = tf.Session() 
for i in range(len(x_data)): 
    x = process_x(x_data[i])[:std_size] 
    y = word[i][:std_size] 
    x_split = tf.split(0, time_step_size, x) 
    outputs, state = tf.nn.rnn(rnn_cell, x_split, state) 

    prediction = tf.reshape(tf.concat(1, outputs), [-1, rnn_size]) 
    real = tf.reshape(y, [-1]) 
    ratio = tf.ones([time_step_size * batch_size]) 

    loss = tf.nn.seq2seq.sequence_loss_by_example([prediction], [real], [ratio]) 
    cost = tf.reduce_mean(loss)/batch_size 
    train = tf.train.AdamOptimizer(0.01).minimize(cost) 

    tf.global_variables_initializer().run(session=sess) 
    step = 0 
    print state 
    while step < 1000: 
     sess.run(train) 
     step+=1 
    result = sess.run(tf.arg_max(prediction, 1)) 
    print result, [t for t in result] == y 
    tf.get_variable_scope().reuse_variables() 

ソースコードが上記のような場合、forループのすべてのステップでrnn_cellとstateが初期化されますか?
他のトレーニングケースで状態を使用する場合は、再利用する必要があります。したがって、rnn_cellとstateは、最初はそれ以降では初期化されません。
このコードの動作を想像できません。Tensorflow変数の初期化

答えて

0

私は、あなたの計算グラフ部分をセッション実行部分と分けなければならないと思います。現在あなたがやっていることは、テンソルフローが通常どのように働くかではありません。おそらくこれを試してみてください:

rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size) 
state = rnn_cell.zero_state(batch_size, tf.float32) 
x_split = tf.split(0, time_step_size, x) 
outputs, state = tf.nn.rnn(rnn_cell, x_split, state) 

prediction = tf.reshape(tf.concat(1, outputs), [-1, rnn_size]) 
real = tf.reshape(y, [-1]) 
ratio = tf.ones([time_step_size * batch_size]) 

loss = tf.nn.seq2seq.sequence_loss_by_example([prediction], [real], [ratio]) 
cost = tf.reduce_mean(loss)/batch_size 
train = tf.train.AdamOptimizer(0.01).minimize(cost) 

init = tf.global_variables_initializer() 
sess = tf.Session() 
sess.run(init) 

for i in range(len(x_data)): 
    x = process_x(x_data[i])[:std_size] 
    y = word[i][:std_size] 

step = 0 
while step < 1000: 
    sess.run(train, feed_dict={x_split:x, real:y}) 
    step+=1 
    result = sess.run(tf.arg_max(prediction, 1)) 
    print result, [t for t in result] == y 

あなたのコードにはデザイン上の問題があるかもしれませんが、あなたのコードはあなたの「トレーニング」と分かれています。

関連する問題