Correct way to split data to batches for Keras stateful RNNs

Submitted by 折月煮酒 on 2019-12-04 07:28:46

The following is based on this answer, which I checked with some tests of my own.

Stateful=False:

Normally (stateful=False), you have one batch with many sequences:

batch_x = [
            [[0],[1],[2],[3],[4],[5]],
            [[1],[2],[3],[4],[5],[6]],
            [[2],[3],[4],[5],[6],[7]],
            [[3],[4],[5],[6],[7],[8]]
          ]

The shape is (4,6,1). This means that you have:

  • 1 batch
  • 4 individual sequences (this is the batch size, and it can vary)
  • 6 steps per sequence
  • 1 feature per step

Every time you train, whether you repeat this batch or pass a new one, the layer sees individual sequences. Every sequence is a unique entry, and nothing is carried over from one batch to the next.
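The numbers in batch_x look like sliding windows over the series 0..8, although that is only an assumption made here for illustration. Under that assumption, a minimal sketch of building such a batch and checking its shape could look like this (make_windows and shape are hypothetical helper names, not Keras functions):

```python
def make_windows(series, steps):
    """Return every contiguous window of `steps` values, one feature per step."""
    count = len(series) - steps + 1
    return [[[value] for value in series[start:start + steps]]
            for start in range(count)]


def shape(nested):
    """Shape of a rectangular nested list, e.g. (4, 6, 1)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Assumed source series: the integers 0..8.
batch_x = make_windows(list(range(9)), steps=6)
print(shape(batch_x))  # (sequences, steps, features)
```

Each row of the result is an independent sample, so with stateful=False the rows could be shuffled or regrouped into batches of any size.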

Stateful=True:

When you switch to a stateful layer, you no longer pass individual sequences. Instead, you pass very long sequences cut into short pieces along the time axis, so you need more batches:

batch_x1 = [
             [[0],[1],[2]],
             [[1],[2],[3]],
             [[2],[3],[4]],
             [[3],[4],[5]]
           ]
batch_x2 = [
             [[3],[4],[5]], #continuation of batch_x1[0]
             [[4],[5],[6]], #continuation of batch_x1[1]
             [[5],[6],[7]], #continuation of batch_x1[2]
             [[6],[7],[8]]  #continuation of batch_x1[3]
           ]

Both shapes are (4,3,1). And this means that you have:

  • 2 batches
  • 4 individual sequences (this is the batch size, and it must stay constant)
  • 6 steps per sequence (3 steps in each batch)
  • 1 feature per step
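The split shown above can be sketched in plain Python. The helper name split_for_stateful is hypothetical. The important detail is that row i of every resulting batch belongs to the same sequence i, which is why the batch size and the row order must not change between batches:

```python
def split_for_stateful(batch, steps_per_batch):
    """Cut each sequence of `batch` into consecutive chunks along the time axis.

    Returns a list of batches; row i of every batch belongs to sequence i.
    """
    total_steps = len(batch[0])
    if any(len(seq) != total_steps for seq in batch):
        raise ValueError("all sequences must have the same length")
    if total_steps % steps_per_batch != 0:
        raise ValueError("sequence length must be a multiple of steps_per_batch")
    return [[seq[start:start + steps_per_batch] for seq in batch]
            for start in range(0, total_steps, steps_per_batch)]


batch_x = [
    [[0], [1], [2], [3], [4], [5]],
    [[1], [2], [3], [4], [5], [6]],
    [[2], [3], [4], [5], [6], [7]],
    [[3], [4], [5], [6], [7], [8]],
]
# Two batches of shape (4, 3, 1), to be fed in this order without shuffling.
batch_x1, batch_x2 = split_for_stateful(batch_x, steps_per_batch=3)
```

Rejecting lengths that do not divide evenly is a design choice of this sketch; padding the last chunk would be another option.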

Stateful layers are meant for huge sequences, long enough to exceed your memory or the time you have available for a task. You slice such sequences and process them in parts. The results are no different: the layer is not smarter and has no additional capabilities. It simply does not assume that the sequences have ended once it has processed a batch. Instead, it expects the next batch to continue those same sequences.

In this case, you decide yourself when the sequences have ended and call model.reset_states() manually.
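To make the "no difference in the results" point concrete, here is a toy illustration in plain Python. ToyStatefulRNN is a hypothetical one-unit recurrence with made-up weights, not Keras's implementation. Its reset_states method only mirrors the idea behind model.reset_states(): processing a sequence in two parts gives the same final state as processing it whole, provided the state is reset only when the sequences really end.

```python
import math


class ToyStatefulRNN:
    """Hypothetical one-unit recurrent layer keeping one state per batch row."""

    def __init__(self, batch_size, w=0.5, u=0.8):
        self.batch_size = batch_size
        self.w, self.u = w, u  # arbitrary fixed weights, for illustration only
        self.reset_states()

    def reset_states(self):
        # Forget everything: the next batch starts brand-new sequences.
        self.states = [0.0] * self.batch_size

    def process(self, batch):
        # Row i continues from the state left by row i of the previous batch.
        for i, sequence in enumerate(batch):
            h = self.states[i]
            for step in sequence:
                h = math.tanh(self.w * step[0] + self.u * h)
            self.states[i] = h
        return list(self.states)


full = [[[0], [1], [2], [3], [4], [5]],
        [[1], [2], [3], [4], [5], [6]]]
first_half = [seq[:3] for seq in full]
second_half = [seq[3:] for seq in full]

layer = ToyStatefulRNN(batch_size=2)
whole = layer.process(full)             # all 6 steps at once

layer.reset_states()                    # the sequences have ended
fresh = layer.process(first_half)       # steps 0..2
in_parts = layer.process(second_half)   # steps 3..5, continuing the state

stale = layer.process(first_half)       # no reset: treated as steps 6..8
print(whole == in_parts, stale == fresh)
```

The last call shows what happens when the reset is forgotten: the layer treats the new data as a continuation of the old sequences, so its output differs from a fresh start.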
