How to understand numpy strides for layman?

南笙 2020-11-27 14:11

I am currently going through numpy, and there is a topic in numpy called "strides". I understand what it is, but how does it work? I did not find any useful information online.

3 Answers
  •  暖寄归人
    2020-11-27 14:33

    Just to add to the great answer by @AndyK: I learnt about numpy strides from Numpy MedKit, which shows how to use them on the following problem.

    Given input:

    >>> import numpy as np
    >>> x = np.arange(20, dtype=np.int32).reshape([4, 5])
    >>> x
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])
    

    Expected Output:

    array([[[  0,  1,  2,  3,  4],
            [  5,  6,  7,  8,  9]],
    
           [[  5,  6,  7,  8,  9],
            [ 10, 11, 12, 13, 14]],
    
           [[ 10, 11, 12, 13, 14],
            [ 15, 16, 17, 18, 19]]])
    

    To do this, we need to know the following terms:

    shape - The dimensions of the array along each axis.

    strides - The number of bytes of memory that must be skipped to progress to the next item along a certain dimension.

    >>> x.strides
    (20, 4)
    
    >>> np.int32().itemsize
    4
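    To make the definition concrete, here is a minimal sketch (assuming the int32 array x from above, so 4 bytes per element) that derives the strides of a C-ordered array from its shape and item size, then uses them to locate one element's byte offset in the raw buffer. The helper c_contiguous_strides is a hypothetical name for illustration, not a NumPy function.

```python
import numpy as np

# int32 is assumed so that each element takes 4 bytes, as in the answer.
x = np.arange(20, dtype=np.int32).reshape(4, 5)

def c_contiguous_strides(shape, itemsize):
    """Hypothetical helper: byte strides of a C-ordered (row-major) array."""
    strides = []
    step = itemsize
    for dim in reversed(shape):   # the last axis moves one item at a time
        strides.append(step)
        step *= dim               # the next axis out skips a whole block
    return tuple(reversed(strides))

print(x.strides)                                  # (20, 4)
print(c_contiguous_strides(x.shape, x.itemsize))  # (20, 4)

# Element x[i, j] sits at byte offset i * strides[0] + j * strides[1].
raw = x.tobytes()
i, j = 2, 3
offset = i * x.strides[0] + j * x.strides[1]
value = np.frombuffer(raw, dtype=np.int32, count=1, offset=offset)[0]
print(offset, value, x[i, j])                     # 52 13 13
```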
    

    Now, if we look at the Expected Output:

    array([[[  0,  1,  2,  3,  4],
            [  5,  6,  7,  8,  9]],
    
           [[  5,  6,  7,  8,  9],
            [ 10, 11, 12, 13, 14]],
    
           [[ 10, 11, 12, 13, 14],
            [ 15, 16, 17, 18, 19]]])
    

    We need to manipulate the array shape and strides. The output shape must be (3, 2, 5): 3 items, each containing a window of m = 2 consecutive rows, with 5 elements per row.

    The strides need to change from (20, 4) to (20, 20, 4). Each item in the output starts one row further down in x, so moving to the next item skips 20 bytes (one row of 5 elements of 4 bytes each). Moving to the next row inside an item also skips 20 bytes, and moving to the next element skips 4 bytes (one int32).
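    The arithmetic above can be checked without any stride tricks. This sketch (again assuming 4-byte int32 elements) computes the byte offset k*20 + r*20 + c*4 for every output position (k, r, c) and reads the element at that offset from x's flat buffer; each item k turns out to be rows k and k + 1 of x.

```python
import numpy as np

x = np.arange(20, dtype=np.int32).reshape(4, 5)
flat = x.ravel()           # the same buffer seen as 20 consecutive int32 values
itemsize = x.itemsize      # 4 bytes

shape = (3, 2, 5)
strides = (20, 20, 4)      # bytes to step along each output axis

out = np.empty(shape, dtype=x.dtype)
for k in range(shape[0]):
    for r in range(shape[1]):
        for c in range(shape[2]):
            # Byte offset of out[k, r, c] inside x's buffer.
            offset = k * strides[0] + r * strides[1] + c * strides[2]
            out[k, r, c] = flat[offset // itemsize]

print(out[1])  # rows 1 and 2 of x
```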

    So:

    >>> from numpy.lib import stride_tricks
    >>> stride_tricks.as_strided(x, shape=(3, 2, 5),
    ...                          strides=(20, 20, 4))
    array([[[  0,  1,  2,  3,  4],
            [  5,  6,  7,  8,  9]],
    
           [[  5,  6,  7,  8,  9],
            [ 10, 11, 12, 13, 14]],
    
           [[ 10, 11, 12, 13, 14],
            [ 15, 16, 17, 18, 19]]])
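    Hard-coding (20, 20, 4) only works for 4-byte elements. A slightly more general sketch reads the strides from the input array itself, so it also works for the platform's default integer dtype. row_windows is a hypothetical helper name, and passing writeable=False is my own precaution, since the windows overlap and writing through them would change several windows at once.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def row_windows(a, m):
    """Hypothetical helper: read-only view of overlapping m-row windows of a 2-D array."""
    n_rows, n_cols = a.shape
    row_stride, col_stride = a.strides
    return as_strided(
        a,
        shape=(n_rows - m + 1, m, n_cols),
        strides=(row_stride, row_stride, col_stride),  # next window = next row
        writeable=False,
    )

x = np.arange(20).reshape(4, 5)     # default integer dtype (4 or 8 bytes)
w = row_windows(x, 2)
print(w.shape)                      # (3, 2, 5)
print(np.shares_memory(w, x))       # True: no data was copied

# The same values built by copying slices, for comparison.
copied = np.stack([x[i:i + 2] for i in range(3)])
print(np.array_equal(w, copied))    # True
```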
    

    An alternative would be:

    >>> d = dict(x.__array_interface__)
    >>> d['shape'] = (3, 2, 5)
    >>> d['strides'] = (20, 20, 4)
    
    >>> class Arr:
    ...     __array_interface__ = d
    ...     base = x
    
    >>> np.array(Arr())
    array([[[  0,  1,  2,  3,  4],
            [  5,  6,  7,  8,  9]],
    
           [[  5,  6,  7,  8,  9],
            [ 10, 11, 12, 13, 14]],
    
           [[ 10, 11, 12, 13, 14],
            [ 15, 16, 17, 18, 19]]])
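    One detail worth knowing about this alternative: np.array copies the data by default, while np.asarray wraps the same memory described by __array_interface__. A small sketch (same int32 assumption) that checks this with np.shares_memory:

```python
import numpy as np

x = np.arange(20, dtype=np.int32).reshape(4, 5)

d = dict(x.__array_interface__)   # data pointer, typestr, shape, strides, ...
d['shape'] = (3, 2, 5)
d['strides'] = (20, 20, 4)        # byte strides, valid for 4-byte int32 only

class Arr:
    __array_interface__ = d
    base = x                      # keeps a reference to the memory owner

view = np.asarray(Arr())          # reuses x's buffer, no copy
copy = np.array(Arr())            # np.array copies by default

print(np.shares_memory(view, x))  # True
print(np.shares_memory(copy, x))  # False
print(np.array_equal(view, copy)) # True
```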
    

    I use this method very often instead of numpy.hstack or numpy.vstack, and in my experience it is much faster, because as_strided only builds a new view onto x's existing memory instead of copying the data.
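    To check the speed claim on your own machine, a rough timing sketch could look like this. The array size is arbitrary, and np.stack over row slices stands in for the hstack/vstack approach (my substitution); timings will vary by machine. Creating the view is cheap, but any later computation that reads it still touches every element.

```python
import timeit
import numpy as np
from numpy.lib.stride_tricks import as_strided

# Arbitrary illustrative size.
x = np.arange(200_000, dtype=np.int32).reshape(2_000, 100)
m = 2
n_windows = x.shape[0] - m + 1

def with_strides():
    # Builds a view; no element is copied.
    return as_strided(x, shape=(n_windows, m, x.shape[1]),
                      strides=(x.strides[0],) + x.strides, writeable=False)

def with_stack():
    # Copies every window into a new array.
    return np.stack([x[i:i + m] for i in range(n_windows)])

assert np.array_equal(with_strides(), with_stack())
print("as_strided:", timeit.timeit(with_strides, number=10))
print("np.stack:  ", timeit.timeit(with_stack, number=10))
```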

    Note:

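    When using this trick with very large or higher-dimensional arrays, calculating the exact strides by hand is not trivial. I usually make a numpy.zeros array of the desired shape, read its strides from array.strides, and pass those to stride_tricks.as_strided. Be aware that this gives the strides of a contiguous array of that shape: for overlapping windows like the example above, the outermost stride must still be the row stride of x (x.strides[0]), not the value from the zeros array. Also, as_strided does no bounds checking, so wrong strides silently read memory outside the array.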
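    A quick sketch comparing the strides of a zeros array with the ones this overlapping-window example actually needs, plus numpy.lib.stride_tricks.sliding_window_view (added in NumPy 1.20), which computes shape and strides for you and cannot read outside the array. The int32 dtype is again assumed.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided, sliding_window_view

x = np.arange(20, dtype=np.int32).reshape(4, 5)

# Strides of a fresh contiguous array of the target shape:
print(np.zeros((3, 2, 5), dtype=np.int32).strides)  # (40, 20, 4)
# ...but overlapping windows need the outer step to be ONE row of x:
needed = (x.strides[0],) + x.strides                # (20, 20, 4)

w_manual = as_strided(x, shape=(3, 2, 5), strides=needed, writeable=False)

# With axis=0 the window axis is appended last, giving shape (3, 5, 2),
# so swap the last two axes to get (3, 2, 5).
w_safe = sliding_window_view(x, window_shape=2, axis=0).swapaxes(1, 2)

print(w_safe.shape)                                 # (3, 2, 5)
print(np.array_equal(w_manual, w_safe))             # True
```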

    Hope it helps!
