How does the shuffle = 'batch' argument of the .fit() method work in the background?

Submitted by 白昼怎懂夜的黑 on 2021-02-07 14:25:30

Question


When I train the model using .fit(), the shuffle argument defaults to True.

Let's say that my dataset has 100 samples and the batch size is 10. When I set shuffle = True, Keras first shuffles the samples (so the 100 samples are now in a different order) and then creates the batches from that new order: batch 1: samples 1-10, batch 2: samples 11-20, etc.
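The shuffle = True behaviour described above can be sketched in plain NumPy. This is a minimal illustration, not Keras's actual code: it assumes the data fits in memory and uses a hypothetical helper, `sample_shuffled_batches`, that permutes all sample indices and then cuts the permuted order into consecutive batches.

```python
import numpy as np


def sample_shuffled_batches(num_samples, batch_size, rng=None):
    """Hypothetical helper: permute all sample indices, then slice the
    permuted order into consecutive batches (the shuffle=True idea)."""
    rng = np.random.default_rng() if rng is None else rng
    order = rng.permutation(num_samples)  # every sample gets a new position
    return [order[i:i + batch_size] for i in range(0, num_samples, batch_size)]


batches = sample_shuffled_batches(100, 10)
print(len(batches))  # 10
print(batches[0])    # 10 indices drawn from anywhere in 0..99
```

Because the permutation happens at the sample level, a batch can contain any mix of samples from the dataset.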

If I set shuffle = 'batch', how is it supposed to work in the background? Intuitively, and using the previous example of a 100-sample dataset with batch size 10, my guess is that Keras first allocates the samples to the batches (i.e. batch 1: samples 1-10 in the dataset's original order, batch 2: samples 11-20, also in the original order, batch 3, and so on) and then shuffles the order of the batches. The model would then be trained on the batches in random order, for example: 3 (contains samples 21-30), 4 (contains samples 31-40), 7 (contains samples 61-70), 1 (contains samples 1-10), ... (I made up this order).
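The guess in the question (fill the batches in the original order, then shuffle only the order in which the batches are visited) can be written as follows. This is again an illustrative sketch with a hypothetical helper, `batch_order_shuffle`, assuming the number of samples is divisible by the batch size. It also checks the key property of the guess: every batch is still a run of consecutive samples, only its position changes.

```python
import numpy as np


def batch_order_shuffle(num_samples, batch_size, rng=None):
    """Hypothetical helper: build batches in the original sample order,
    then shuffle only the order in which the batches are visited."""
    rng = np.random.default_rng() if rng is None else rng
    indices = np.arange(num_samples)
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    visit_order = rng.permutation(len(batches))  # random permutation of batch positions 0..9
    return [batches[k] for k in visit_order]


shuffled = batch_order_shuffle(100, 10)
# Each batch is still a run of 10 consecutive samples, e.g. [20 ... 29].
contiguous = all(np.array_equal(b, np.arange(b[0], b[0] + 10)) for b in shuffled)
print(contiguous)  # True
```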

Is my thinking right or am I missing something?

Thanks!


Answer 1:


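Looking at the implementation at this link (line 349 of training.py), the answer seems to be yes: the samples inside each batch keep their original order, only the order of the batches is shuffled, and any leftover samples that do not fill a complete batch are appended at the end in their original order.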

You can check this with the following code:

import numpy as np
def batch_shuffle(index_array, batch_size):
    """Shuffles an array in a batch-wise fashion.
    Useful for shuffling HDF5 arrays
    (where one cannot access arbitrary indices).
    # Arguments
        index_array: array of indices to be shuffled.
        batch_size: integer.
    # Returns
        The `index_array` array, shuffled in a batch-wise fashion.
    """
    batch_count = len(index_array) // batch_size
    # to reshape we need to be cleanly divisible by batch size,
    # so we stash the extra items and re-append them after shuffling
    last_batch = index_array[batch_count * batch_size:]
    index_array = index_array[:batch_count * batch_size]
    # one row per batch
    index_array = index_array.reshape((batch_count, batch_size))
    # shuffles the rows in place, i.e. whole batches; order inside a batch is kept
    np.random.shuffle(index_array)
    index_array = index_array.flatten()
    return np.append(index_array, last_batch)


x = np.arange(100)  # sample indices 0..99
x_s = batch_shuffle(x, 10)
print(x_s)  # ten runs of 10 consecutive indices, in random batch order
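The comments in batch_shuffle say that leftover items are stashed and re-appended after shuffling. A quick check with 105 samples (not divisible by 10) makes that visible: every full batch remains a contiguous run, and the last five indices always stay at the very end. The function is re-declared in compact form so the snippet runs on its own; the checks themselves are only illustrative and not part of Keras.

```python
import numpy as np


def batch_shuffle(index_array, batch_size):
    """Same logic as above: shuffle whole batches, keep the remainder last."""
    batch_count = len(index_array) // batch_size
    last_batch = index_array[batch_count * batch_size:]  # leftover items
    full = index_array[:batch_count * batch_size].reshape((batch_count, batch_size))
    np.random.shuffle(full)  # permutes rows, i.e. whole batches
    return np.append(full.flatten(), last_batch)


x = np.arange(105)          # 10 full batches + 5 leftover samples
x_s = batch_shuffle(x, 10)

full_part = x_s[:100].reshape(10, 10)
rows_contiguous = all(np.array_equal(row, np.arange(row[0], row[0] + 10)) for row in full_part)
print(rows_contiguous)      # True: each batch is still a run of 10
print(x_s[100:])            # [100 101 102 103 104]: remainder stays at the end
```

So with shuffle = 'batch' the final partial batch is never moved; only the complete batches change position.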


Source: https://stackoverflow.com/questions/45567692/how-does-shuffle-batch-argument-of-the-fit-layer-work-in-the-background
