2017-12-11 12 views
3

、次のように得られる各時点のための機能dynamic_rnn出力(mとして知られる)のみ隠れ状態:Tensorflow:dynamic_rnnを使用してLSTMCellから中間セル状態(c)を取得するにはどうすればよいですか?デフォルトで

cell = tf.contrib.rnn.LSTMCell(100) 
rnn_outputs, _ = tf.nn.dynamic_rnn(cell, 
            inputs=inputs, 
            sequence_length=sequence_lengths, 
            dtype=tf.float32) 

はで方法得る中間体(最終ではない)セル状態(c)があります添加?

それが細胞ラッパーで行うことができますtensorflow貢献mentions

class Wrapper(tf.nn.rnn_cell.RNNCell): 
    def __init__(self, inner_cell): 
    super(Wrapper, self).__init__() 
    self._inner_cell = inner_cell 
    @property 
    def state_size(self): 
    return self._inner_cell.state_size 
    @property 
    def output_size(self): 
    return (self._inner_cell.state_size, self._inner_cell.output_size) 
    def call(self, input, state) 
    output, next_state = self._inner_cell(input, state) 
    emit_output = (next_state, output) 
    return emit_output, next_state 

しかし、動作するようには思えません。何か案は?

答えて

2

提案された解決策は私には役立ちますが、Layer.callメソッド仕様がより一般的です。そのため、次のはAPIの変更に対してより堅牢である必要があります。汝この:

n_steps = 2 
n_inputs = 3 
n_neurons = 5 

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs]) 
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False)) 
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32) 
print(outputs, states) 

X_batch = np.array([ 
    # t = 0  t = 1 
    [[0, 1, 2], [9, 8, 7]], # instance 0 
    [[3, 4, 5], [0, 0, 0]], # instance 1 
    [[6, 7, 8], [6, 5, 4]], # instance 2 
    [[9, 0, 1], [3, 2, 1]], # instance 3 
]) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    outputs_val = outputs[0].eval(feed_dict={X: X_batch}) 
    print(outputs_val) 

返さoutputsはすべてLSTM状態と出力をしている(?, 2, 10)(?, 2, 5)テンソルのタプルです:

class Wrapper(tf.nn.rnn_cell.RNNCell): 
    def __init__(self, inner_cell): 
    super(Wrapper, self).__init__() 
    self._inner_cell = inner_cell 

    @property 
    def state_size(self): 
    return self._inner_cell.state_size 

    @property 
    def output_size(self): 
    return (self._inner_cell.state_size, self._inner_cell.output_size) 

    def call(self, input, *args, **kwargs): 
    output, next_state = self._inner_cell(input, *args, **kwargs) 
    emit_output = (next_state, output) 
    return emit_output, next_state 

は、ここでのテストです。私がではなくtf.nn.rnn_cellパッケージからLSTMCellの「段階的」バージョンを使用していることに注意してください。 LSTMStateTupleを処理しないようにするには、state_is_tuple=Trueに注意してください。

0

マキシムの考えに基づき、私は次の解決策になってしまった:

class StatefulLSTMCell(LSTMCell): 
    def __init__(self, *args, **kwargs): 
     super(StatefulLSTMCell, self).__init__(*args, **kwargs) 

    @property 
    def output_size(self): 
     return (self.state_size, super(StatefulLSTMCell, self).output_size) 

    def call(self, input, state): 
     output, next_state = super(StatefulLSTMCell, self).call(input, state) 
     emit_output = (next_state, output) 
     return emit_output, next_state 
関連する問題