私は強化学習に取り組んでおり、学習中にsess.run()を使用して学習したデータの量を減らして学習をスピードアップしたいと考えています。TensorFlow:グラフ内のLSTM状態の保存/更新
私はLSTMに探していたし、楽しみにして適切なQ値を見つけるためにリセットする必要があると、私は)(tf.caseでこのようなソリューションを細工:
CurrentStateOption = tf.Variable(0, trainable=False, name='SavedState')
with tf.name_scope("LSTMLayer") as scope:
initializer = tf.random_uniform_initializer(-.1, .1)
lstm_cell_L1 = tf.nn.rnn_cell.LSTMCell(self.input_sizes, forget_bias=1.0, initializer=initializer, state_is_tuple=True)
self.cell_L1 = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_L1] *self.NumberLSTMLayers, state_is_tuple=True)
self.state = self.cell_L1.zero_state(1,tf.float64)
self.SavedState = self.cell_L1.zero_state(1,tf.float64) #tf.Variable(state, trainable=False, name='SavedState')
#SaveCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(1)), self.SaveState, self.SameState)
#RestoreCond = tf.cond(tf.equal(CurrentStateOption,tf.constant(-1)), self.RestoreState, self.SameState)
#ZeroCond = tf.cond(tf.less(CurrentStateOption,tf.constant(-1)), self.ZeroState, self.SameState)
self.state = tf.case({tf.equal(CurrentStateOption,tf.constant(1)): self.SaveState, tf.equal(CurrentStateOption,tf.constant(-1)): self.RestoreState,
tf.less(CurrentStateOption,tf.constant(-1)): self.ZeroState}, default=self.SameState, exclusive=True)
RunConditions = tf.group([SaveCond, RestoreCond, ZeroCond])
self.Xinputs = [tf.concat(1,[Xinputs])]
outputs, stateFINAL_L1 = rnn.rnn(self.cell_L1,self.Xinputs, initial_state=self.state, dtype=tf.float32)
def RestoreState(self): #self.state = self.state.assign(self.SavedState) self.state = self.SavedState return self.state def ZeroState(self): self.state = self.cell_L1.zero_state(1,tf.float64) return self.state def SaveState(self): #self.SavedState = self.SavedState.assign(self.state) self.SavedState = self.state return self.SavedState def SameState(self): return self.state
これは私がwが何をすべきかLSTMグラフを指示するINTを供給できるようになりました概念にうまく動作するようですith状態。パス "1"を実行する前に状態を保存します。 "-1"を渡すと最後に保存された状態に戻ります。 "< -1"を渡すと状態はゼロになります。 "0"の場合は、最後に実行したときのLSTM(推論)を使用します。私は単純なtf.cond()アプローチを含むいくつかの異なるアプローチを試みました。
テンソルを必要とするtf.case()オペレーションが、LSTMの状態がタプルである(タプル以外のタプルが減価償却される)から生じると考えられる問題です。これは、グラフ変数に値をtf.assign()しようとすると明らかになりました。
私の最終目標はグラフ内に「状態」を残し、INTを渡して状態をどう処理するかを指示することです。将来、私はさまざまなルックバックのために複数の "店舗"の場所を持っていたいと思います。
タフルとテンソルの構造体のtf.case()型を扱う方法はありますか?