TensorflowでLSTMを使用して分類したい長さのシーケンスがあります。分類のためには、各シーケンスの最後のタイムステップのLSTM出力が必要です。TensorflowでのLSTMシーケンスの最終値の取得
max_length = 10
n_dims = 2
layer_units = 5
input = tf.placeholder(tf.float32, [None, max_length, n_dims])
lengths = tf.placeholder(tf.int32, [None])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)
sequence_outputs, last_states = tf.nn.dynamic_rnn(cell, sequence_length=lengths, inputs=input)
私はnumpyの表記で、取得したいと思います:output = sequence_outputs[:,lengths]
はTensorflowでこの動作を取得する方法や回避策はありますか?
--- --- UPDATE
がこの投稿をHow to select rows from a 3-D Tensor in TensorFlow?に続き、それはtf.gather
とインデックスを操作すると効率的な方法で問題を解決することができるようです。唯一の要件は、バッチサイズを事前に知っていなければならないことです。ここでは、この具体的な問題への言及ポストの適応は次のとおりです。
max_length = 10
n_dims = 2
layer_units = 5
batch_size = 2
input = tf.placeholder(tf.float32, [batch_size, max_length, n_dims])
lengths = tf.placeholder(tf.int32, [batch_size])
cell = tf.nn.rnn_cell.LSTMCell(num_units=layer_units, state_is_tuple=True)
sequence_outputs, last_states = tf.nn.dynamic_rnn(cell,
sequence_length=lengths, inputs=input)
#Code adapted from @mrry response in StackOverflow:
#https://stackoverflow.com/questions/36088277/how-to-select-rows-from-a-3-d-tensor-in-tensorflow
rows_per_batch = tf.shape(input)[1]
indices_per_batch = 1
# Offset to add to each row in indices. We use `tf.expand_dims()` to make
# this broadcast appropriately.
offset = tf.range(0, batch_size) * rows_per_batch
# Convert indices and logits into appropriate form for `tf.gather()`.
flattened_indices = lengths - 1 + offset
flattened_sequence_outputs = tf.reshape(self.sequence_outputs, tf.concat(0, [[-1],
tf.shape(sequence_outputs)[2:]]))
selected_rows = tf.gather(flattened_sequence_outputs, flattened_indices)
last_output = tf.reshape(selected_rows,
tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
tf.shape(self.sequence_outputs)[2:]]))
@petruxオプション(Get the last output of a dynamic_rnn in TensorFlow)が動作するようにも思えるが、私はしなかったものの、forループ内のリストを構築する必要性が少なく、最適化することができますこのステートメントをサポートするためのベンチマークを実行してください。
うまくいくようです。私が探していた事前のインデックス作成は未解決の問題であるようです:https://github.com/tensorflow/tensorflow/issues/4638 あなたの実装が効率を大幅に低下させるか、それが妥当かどうか疑問に思います。 – pabaldonedo