2016-03-07 21 views
5

長さの異なるシーケンスのバッチに対してLSTMセルをトレーニングしています。 tf.nn.rnnには非常に便利なパラメータsequence_lengthがありますが、それを呼び出した後、バッチ内の各項目の最後のタイムステップに対応する出力行を選択する方法がわかりません。次のように私のコードは基本的にテンソルフローから最後の有効な出力値を取得する方法RNN

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size) 
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths) 

lstm_outputsは、各時間ステップにおけるLSTM出力とのリストです。ただし、バッチ内の各アイテムの長さが異なるため、バッチ内の各アイテムに有効な最後のLSTM出力を含むテンソルを作成したいと考えています。

私はnumpyのインデックスを使用することができれば、私はちょうどこのようなものだろう:

all_outputs = tf.pack(lstm_outputs) 
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :] 

をしかし、それは(私がfeature requestの承知している)時間tensorflowがそれをサポートしていない始めることが判明。

どうすればこれらの値を取得できますか?

答えて

5

danijarが機能要求ページでより適切な回避策を公開しました。質問にリンクしました。それは大きなプラスであるテンソルを評価する必要はありません。

テンソルフロー0.8で動作させました。ここでは、コードは次のようになります。

def extract_last_relevant(outputs, length): 
    """ 
    Args: 
     outputs: [Tensor(batch_size, output_neurons)]: A list containing the output 
      activations of each in the batch for each time step as returned by 
      tensorflow.models.rnn.rnn. 
     length: Tensor(batch_size): The used sequence length of each example in the 
      batch with all later time steps being zeros. Should be of type tf.int32. 

    Returns: 
     Tensor(batch_size, output_neurons): The last relevant output activation for 
      each example in the batch. 
    """ 
    output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2]) 
    # Query shape. 
    batch_size = tf.shape(output)[0] 
    max_length = int(output.get_shape()[1]) 
    num_neurons = int(output.get_shape()[2]) 
    # Index into flattened array as a workaround. 
    index = tf.range(0, batch_size) * max_length + (length - 1) 
    flat = tf.reshape(output, [-1, num_neurons]) 
    relevant = tf.gather(flat, index) 
    return relevant 
2

最高の解決策ではありませんが、出力を評価してnumpyインデックスを使用して結果を取得し、そこからテンソル変数を作成できますか?テンソルフローがこの特徴を得るまで、ストップギャップとして機能するかもしれません。例えば

all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'}) 
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :] 
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs) 
+0

ええ、それは間違いなく素敵な解決策ではないでしょう。しかし、今私は他の方法を見ることはできません。 – erickrf

+1

もっと良い方法はありますか? – Zhao

1

あなたが最後の有効な出力にのみ興味があるなら、あなたはそれが常にcが最後の状態であるタプル(C、H)だということを考慮tf.nn.rnn()によって返された状態を経て、それを取得することができますhは最後の出力です。状態はあなたが(tensorflow 0.12で働いて)次のコードを使用することができますLSTMStateTupleのとき:

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size) 
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths) 
last_output = state[1] 
+0

これはテンソル0.12以降の受け入れられた答えよりはるかに簡単です – erickrf

関連する問題