2016-10-14 7 views
0

私はテンソルフローで生成的なRNNモデルを作成しようとしています。 RNNライブラリでstate_is_tupeへの新しい切り替えがデフォルトでtrueになっているため、バッチ間で状態を保存する最善の方法を見つけるのが苦労しています。私はそれを虚偽に戻すことができることを知っていますが、非難されているので、私はそれをしたくありません。私がトレーニングを終えたら、session.runの呼び出しの間に隠れた状態を維持できるようにする必要があります。これは、一度に1つのサンプルを生成するためです。私は次のようにrnnの状態を返すことができることを理解しました。Tensorflow、feed_dictでMultiRNN状態を渡す方法

 rnn = tf.nn.rnn_cell.MultiRNNCell(cells) 
     zero_state = rnn.zero_state(batch_size, tf.float32) 
     output, final_state = tf.nn.dynamic_rnn(rnn, self.input_sound, initial_state = zero_state) 
     sess = tf.Session() 
     sess.run(tf.initialize_all_variables()) 
     state_output = sess.run(final_state, feed_dict = {self.input_sound: np.zeros((64, 32, 512))}) 

これは素晴らしいことですが、state_outputをモデルに戻したいときに問題が発生します。プレースホルダはテンソルオブジェクトでなければならないので、state_outputタプルに戻すことはできません。

私は非常に一般的な解決策を探しています。 rnnは、MultiRNNCellまたは単一のLSTMCell、または想像できる他の組み合わせであってもよい。

答えて

0

私はそれを理解したと思います。私は状態タプルを単一の1Dテンソルに平坦化するために次のコードを使用しました。私はそれをrnnセルのサイズ指定に従ってモデルに戻すと、それを切り詰めるよりもむしろそれを行うことができます。

def flatten_state_tupel(x): 
    result = [] 
    for x_ in x: 
     if isinstance(x_, tf.Tensor) or not hasattr(x_, '__iter__'): 
      result.append(x_) 
     else: 
      result.extend(flatten_state_tupel(x_)) 
    return result 

def pack_state_tupel(state): 
    return tf.concat(0, [tf.reshape(s, (-1,)) for s in flatten_state_tupel(state)]) 

def unpack_state_tupel(state, size): 
    state = tf.reshape(state, (-1, tf.reduce_sum(flatten_state_tupel(size)))) 
    def _make_state_tupel(sz, i): 
     if hasattr(sz, '__iter__'): 
      result = [] 
      for s in sz: 
       base_index, y = _make_state_tupel(s, i) 
       result.append(y) 
      return base_index, tf.nn.rnn_cell.LSTMStateTuple(*result) if isinstance(sz, tf.nn.rnn_cell.LSTMStateTuple) else tuple(result) 
     else: 
      return i + sz, state[..., i : i + sz] 
    return _make_state_tupel(size, 0)[-1] 

私は次のような関数を使用します。

rnn = tf.nn.rnn_cell.MultiRNNCell(cells) 
zero_state = pack_state_tupel(rnn.zero_state(batch_size, tf.float32)) 
self.initial_state = tf.placeholder_with_default(zero_state, None) 

output, final_state = tf.nn.dynamic_rnn(rnn, self.input_sound, initial_state = unpack_state_tupel(self.initial_state, rnn.state_size)) 

packed_state = pack_state_tupel(final_state) 

sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 

state_output = sess.run(packed_state, feed_dict = {self.input_sound: np.zeros((64, 32, 512))}) 
print(state_output.shape) 
state_output = sess.run(packed_state, feed_dict = {self.input_sound: np.zeros((64, 32, 512)), self.initial_state: np.zeros(state_output.shape[0])}) 
print(state_output) 

私はしかし、私は保存し、世代間、バッチ間の状態を渡すことができます(トレーニング中のケースとなります)何も渡さない場合、それは状態をゼロにしますこの方法です。

関連する問題