私はUdacityのLSTMチュートリアルに従っていますが、LSTMの入力データフォーマットを理解するのは苦労しています。 https://github.com/rndbrtrnd/udacity-deep-learning/blob/master/6_lstm.ipynbTensorflow LSTMの入力フォーマットの例batches2string
以下のコードでnum_unrollingについて説明できますか?またはLSTMモデルのトレーニングバッチを生成する方法は?
batch_size=64
num_unrollings=10
class BatchGenerator(object):
def __init__(self, text, batch_size, num_unrollings):
self._text = text
self._text_size = len(text)
self._batch_size = batch_size
self._num_unrollings = num_unrollings
segment = self._text_size // batch_size
self._cursor = [ offset * segment for offset in range(batch_size)]
self._last_batch = self._next_batch()
def _next_batch(self):
"""Generate a single batch from the current cursor position in the data."""
batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float)
for b in range(self._batch_size):
batch[b, char2id(self._text[self._cursor[b]])] = 1.0
self._cursor[b] = (self._cursor[b] + 1) % self._text_size
return batch
def next(self):
"""Generate the next array of batches from the data. The array consists of
the last batch of the previous array, followed by num_unrollings new ones.
"""
batches = [self._last_batch]
for step in range(self._num_unrollings):
batches.append(self._next_batch())
self._last_batch = batches[-1]
return batches
def characters(probabilities):
"""Turn a 1-hot encoding or a probability distribution over the possible
characters back into its (most likely) character representation."""
return [id2char(c) for c in np.argmax(probabilities, 1)]
def batches2string(batches):
"""Convert a sequence of batches back into their (most likely) string
representation."""
s = [''] * batches[0].shape[0]
for b in batches:
s = [''.join(x) for x in zip(s, characters(b))]
return s
train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
valid_batches = BatchGenerator(valid_text, 1, 1)
print(batches2string(train_batches.next()))
print(batches2string(train_batches.next()))
カーソルがあります。しかし、なぜ64個のバッチの最初の10文字(num_unrolling)以外のテキストを捨てるのでしょうか?
入力形式の理解に役立つリソースや例を教えてください。ありがとう!
徹底的な回答をお寄せいただきありがとうございます。私はこれに私の脳を掛けていた。 'batches2string(train_batches.next())'は、期待される順番(リスト、 'vocabulary_size' one-hotsの' num_unrollings + 1'文字の 'batch_order'長さ)を示します。 しかし、 'train_batches.next()'はちょうど畳み込まれているようです。これは 'num_unrollings + 1'のリストで、各メンバーは文字の' vocabulary_size'ワンホットリストの 'batch_size'リストを含む配列です。 – GAEfan
この詳細な回答をありがとうございました。それは素晴らしい。最後の5文字 - >最後の1文字 – Aaron