Shift elements in a numpy array

萌比男神i 2020-12-01 00:54

Following up on this question from years ago: is there a canonical "shift" function in numpy? I don't see anything in the documentation.

Here's a simple version o…
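For context, np.roll is the closest built-in, but it wraps elements around the end instead of discarding them. Below is a minimal sketch of the non-wrapping behavior the question asks for. It uses a hypothetical helper named shift_fill, which is not a NumPy function.

```python
import numpy as np

def shift_fill(arr, num, fill_value=np.nan):
    """Shift a 1-D array by num positions without wrap-around.

    Hypothetical helper for illustration: positive num shifts right,
    negative num shifts left, vacated slots receive fill_value.
    """
    # Start from an array full of fill_value; the dtype is widened if needed
    # (e.g. int data + NaN fill -> float64 result).
    result = np.full(arr.shape, fill_value, dtype=np.result_type(arr, fill_value))
    if num > 0:
        result[num:] = arr[:-num]
    elif num < 0:
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result

a = np.arange(5)
print(np.roll(a, 2))     # [3 4 0 1 2]: elements wrap around
print(shift_fill(a, 2))  # nan, nan, 0.0, 1.0, 2.0: vacated slots are filled
```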

8 answers
  •  旧时难觅i
    2020-12-01 01:12

    Benchmarks & introducing Numba

    1. Summary

    • The accepted answer (scipy.ndimage.interpolation.shift) is the slowest solution listed on this page.
    • Numba (@numba.njit) gives a performance boost when the array size is smaller than ~25,000.
    • Any method is about equally good when the array size is large (> 250,000).
    • The fastest option really depends on
          (1) the length of your arrays, and
          (2) the amount of shift you need.
    • The figure below shows the timings of all the methods listed on this page (as of 2020-07-11), using a constant shift of 10. With small arrays, some methods take more than 2000% longer than the best method.

    [Figure: Relative timings, constant shift (10), all methods]
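    Relative timings like these can be measured with the standard timeit module. The harness below is a minimal sketch, not the script behind the figure. relative_timings is a hypothetical name, the sizes and repeat counts are arbitrary, and only two pure-NumPy variants are compared (shift4 and shift5 without Numba). Each entry is a method's best time divided by the fastest method's best time, so 1.0 marks the winner and 21.0 would correspond to "+2000%".

```python
import timeit
import numpy as np

def shift4(arr, num, fill_value=np.nan):
    # np.concatenate + np.full (shift4_numba without the decorator)
    n = arr.shape[0]
    if num >= 0:
        return np.concatenate((np.full(num, fill_value), arr[:n - num]))
    return np.concatenate((arr[-num:], np.full(-num, fill_value)))

def shift5(arr, num, fill_value=np.nan):
    # Two slice assignments into a preallocated array
    result = np.empty_like(arr)
    if num > 0:
        result[:num] = fill_value
        result[num:] = arr[:-num]
    elif num < 0:
        result[num:] = fill_value
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result

def relative_timings(funcs, sizes, shift, repeat=5, number=200):
    """Return {size: {name: best_time / fastest_best_time}}.

    1.0 marks the fastest method; 21.0 would mean "+2000%".
    """
    table = {}
    for n in sizes:
        arr = np.random.rand(n)  # float data, so a NaN fill is valid
        best_times = {}
        for name, func in funcs.items():
            # Minimum over `repeat` runs of `number` calls each
            runs = timeit.repeat(lambda: func(arr, shift), repeat=repeat, number=number)
            best_times[name] = min(runs)
        fastest = min(best_times.values())
        table[n] = {name: t / fastest for name, t in best_times.items()}
    return table

table = relative_timings({"shift4": shift4, "shift5": shift5}, sizes=[100, 10_000], shift=10)
for n, row in table.items():
    print(n, {name: round(ratio, 2) for name, ratio in row.items()})
```

    Taking the minimum over several runs reduces noise from other processes. Varying both sizes and shift is what reveals the dependence on array length and shift amount noted in the summary.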

    2. Detailed benchmarks with the best options

    • Choose shift4_numba (defined below) if you want a good all-rounder.

    3. Code

    3.1 shift4_numba

    • Good all-rounder: at most 20% slower than the best method at any array size.
    • Best method for medium array sizes: ~500 < N < 20,000.
    • Caveat: Numba's JIT (just-in-time) compiler gives a performance boost only if you call the decorated function more than once. The first call usually takes 3-4 times longer than subsequent calls.
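    import numba
    import numpy as np
    
    @numba.njit
    def shift4_numba(arr, num, fill_value=np.nan):
        n = arr.shape[0]
        if num >= 0:
            # Prepend num fill values and drop the last num elements.
            # (arr[:-num] would be empty for num == 0, so slice up to n - num.)
            return np.concatenate((np.full(num, fill_value), arr[:n - num]))
        else:
            # Drop the first -num elements and append -num fill values.
            return np.concatenate((arr[-num:], np.full(-num, fill_value)))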
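    You can observe the first-call caveat by timing two consecutive calls. This is a minimal sketch that assumes Numba is installed. If it is not, a stand-in decorator (illustration only, not part of Numba) lets the script run as plain NumPy, and both calls then cost roughly the same.

```python
import time
import numpy as np

try:
    from numba import njit
except ImportError:
    # Stand-in so the sketch runs without Numba: no compilation happens.
    def njit(func):
        return func

@njit
def shift4_numba(arr, num, fill_value=np.nan):
    n = arr.shape[0]
    if num >= 0:
        return np.concatenate((np.full(num, fill_value), arr[:n - num]))
    else:
        return np.concatenate((arr[-num:], np.full(-num, fill_value)))

def timed_call(func, *args):
    """Return (result, elapsed seconds) for a single call."""
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start

arr = np.random.rand(1_000)
# First call: with Numba this includes JIT compilation for these argument types.
first, t_first = timed_call(shift4_numba, arr, 10)
# Second call with the same argument types reuses the compiled code.
second, t_second = timed_call(shift4_numba, arr, 10)
print(f"first call: {t_first:.6f} s, second call: {t_second:.6f} s")
```

    Numba compiles one specialization per argument-type signature. Calling the function later with a different dtype (for example a float32 array) therefore triggers another compilation.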
    

    3.2 shift5_numba

    • Best option for small arrays (N <= 300 to 1,500); the threshold depends on the amount of shift needed.
    • Good performance at any array size: at most 50% slower than the fastest solution.
    • Caveat: Numba's JIT (just-in-time) compiler gives a performance boost only if you call the decorated function more than once. The first call usually takes 3-4 times longer than subsequent calls.
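    import numba
    import numpy as np
    
    @numba.njit
    def shift5_numba(arr, num, fill_value=np.nan):
        # Output keeps arr's dtype, so fill_value must fit it (NaN needs a float array).
        result = np.empty_like(arr)
        if num > 0:
            result[:num] = fill_value      # vacated slots on the left
            result[num:] = arr[:-num]
        elif num < 0:
            result[num:] = fill_value      # vacated slots on the right
            result[:num] = arr[-num:]
        else:
            result[:] = arr                # no shift: plain copy
        return result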
    

    3.3 shift5

    • Best method for array sizes ~20,000 < N < 250,000.
    • Same as shift5_numba; just remove the @numba.njit decorator.
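    Without the decorator, shift5 is plain NumPy. One thing to watch: np.empty_like(arr) keeps the input dtype, so the default NaN fill only works for floating-point arrays. NumPy raises a ValueError when asked to write NaN into an integer array. The sketch below keeps the logic above and shows two workarounds: pass an integer fill_value, or convert to float first. Both are suggestions, not part of the original answer.

```python
import numpy as np

def shift5(arr, num, fill_value=np.nan):
    # Preallocate an output with the same shape and dtype as arr.
    result = np.empty_like(arr)
    if num > 0:
        result[:num] = fill_value      # vacated slots on the left
        result[num:] = arr[:-num]      # remaining elements move right
    elif num < 0:
        result[num:] = fill_value      # vacated slots on the right
        result[:num] = arr[-num:]      # remaining elements move left
    else:
        result[:] = arr                # no shift: plain copy
    return result

ints = np.arange(6)
print(shift5(ints, 2, fill_value=0))   # integer fill keeps the int dtype
print(shift5(ints.astype(float), -2))  # cast to float first to fill with NaN

try:
    shift5(ints, 2)                    # default NaN fill on an int array
except ValueError as exc:
    print("ValueError:", exc)
```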

    4. Appendix

    4.1 Details about the methods used

    • shift_scipy: scipy.ndimage.interpolation.shift (scipy 1.4.1), the option from the accepted answer, which is clearly the slowest alternative.
    • shift1: np.roll and out[:num] = np.nan by IronManMark20 & gzc
    • shift2: np.roll and np.put by IronManMark20
    • shift3: np.pad and slice by gzc
    • shift4: np.concatenate and np.full by chrisaycock
    • shift5: two result[slice] = x assignments by chrisaycock
    • shift#_numba: @numba.njit decorated versions of the previous.

    shift2 and shift3 use functions that were not supported by the then-current Numba version (0.50.1).
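    The appendix only names the remaining approaches. Below are minimal sketches of shift1 to shift3, reconstructed from those one-line descriptions. They are not the original posters' code, and they assume a 1-D floating-point array (so NaN is a valid fill) and |num| <= len(arr). The loop at the end checks that all three agree for positive, negative and zero shifts.

```python
import numpy as np

def shift1(arr, num, fill_value=np.nan):
    # np.roll returns a rotated copy; overwrite the wrapped-around slots.
    result = np.roll(arr, num)
    if num > 0:
        result[:num] = fill_value
    elif num < 0:
        result[num:] = fill_value
    return result

def shift2(arr, num, fill_value=np.nan):
    # Same rotation, but np.put writes fill_value at explicit flat indices.
    result = np.roll(arr, num)
    n = result.shape[0]
    if num > 0:
        np.put(result, np.arange(num), fill_value)
    elif num < 0:
        np.put(result, np.arange(n + num, n), fill_value)
    return result

def shift3(arr, num, fill_value=np.nan):
    # Pad |num| fill values on one side, then slice back to the original length.
    n = arr.shape[0]
    if num >= 0:
        padded = np.pad(arr, (num, 0), mode="constant", constant_values=fill_value)
        return padded[:n]
    padded = np.pad(arr, (0, -num), mode="constant", constant_values=fill_value)
    return padded[-num:]

x = np.arange(8, dtype=float)
for k in (3, -3, 0):
    outs = [f(x, k) for f in (shift1, shift2, shift3)]
    # equal_nan=True so that NaN fill slots compare as equal
    assert all(np.array_equal(outs[0], o, equal_nan=True) for o in outs[1:])
print(shift3(x, 3))
```

    The np.roll-based variants do extra work rotating elements that are then overwritten. This is one plausible reason they trail the slice-assignment approach in the timings.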

    4.2 Other test results

    4.2.1 Relative timings, all methods

    • Relative timings, 10% shift, all methods
    • Relative timings, constant shift (10), all methods

    4.2.2 Raw timings, all methods

    • Raw timings, constant shift (10), all methods
    • Raw timings, 10% shift, all methods

    4.2.3 Raw timings, few best methods

    • Raw timings with small arrays, constant shift (10), few best methods
    • Raw timings with small arrays, 10% shift, few best methods
    • Raw timings with large arrays, constant shift (10), few best methods
    • Raw timings with large arrays, 10% shift, few best methods
