Efficiently multiply elements of each row together

半阙折子戏 2020-12-07 04:27

Given an ndarray of shape (n, 3) with n around 1000, how can I multiply together all the elements of each row, fast? The (inelegant) second solution below

2 answers
  • 2020-12-07 05:05

    np.prod accepts an axis argument:

    np.prod(a, axis=1)
    

    With axis=1, the product is taken across the columns, so you get one value per row.
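To make the effect of the axis argument concrete, here is a minimal sketch on a tiny array. The array values are made up for illustration and are not from the question.

```python
import numpy as np

# Illustrative 2x3 array (assumed data, not from the original post).
a = np.array([[1, 2, 3],
              [4, 5, 6]])

# axis=1 collapses the column dimension: one product per row.
row_products = np.prod(a, axis=1)
print(row_products.tolist())  # → [6, 120]
```

The result has shape (n,), one entry per input row.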

    Sanity check

    assert np.array_equal(np.prod(a, axis=1), prod1(a))
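The check above relies on prod1 from the question, which is not reproduced here. As a self-contained sketch, the same kind of check can be run against a plain Python loop; prod_loop below is a hypothetical stand-in and not the asker's actual function.

```python
import numpy as np

def prod_loop(a):
    # Hypothetical reference implementation: multiply each row element by element.
    out = np.empty(a.shape[0], dtype=a.dtype)
    for i, row in enumerate(a):
        p = 1
        for x in row:
            p = p * x
        out[i] = p
    return out

# Assumed test data: 1000 rows of 3 random floats.
rng = np.random.default_rng(0)
a = rng.random((1000, 3))

# allclose rather than array_equal, to tolerate floating-point rounding differences.
assert np.allclose(np.prod(a, axis=1), prod_loop(a))
```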
    

    Performance

    17.6 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    

    (1000x speedup)
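The timing above looks like IPython %timeit output. Outside IPython, a rough equivalent can be sketched with the standard timeit module. The array contents and repeat counts here are assumptions, and the numbers will vary by machine.

```python
import timeit
import numpy as np

# Assumed input: a (1000, 3) array of random floats.
rng = np.random.default_rng(0)
a = rng.random((1000, 3))

# Run 5 batches of 1000 calls each and keep the fastest batch.
calls_per_batch = 1000
runs = timeit.repeat(lambda: np.prod(a, axis=1), number=calls_per_batch, repeat=5)
best_us = min(runs) / calls_per_batch * 1e6
print(f"np.prod(a, axis=1): {best_us:.2f} µs per call (best of 5)")
```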

  • 2020-12-07 05:15

    Improving performance further

    First, a general rule of thumb: you are working with numerical arrays, so use arrays, not lists. Lists may look somewhat like a general array, but they are completely different under the hood and entirely unsuitable for most numerical calculations.
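As a small illustration of this rule of thumb, a nested list can be converted to an array once, up front, so that all later arithmetic runs on the array. The data below is made up for the example.

```python
import numpy as np

# Assumed input: rows stored as a nested Python list.
rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# Convert once; from here on, work with the ndarray only.
a = np.asarray(rows, dtype=np.float64)

# Column slices multiply element-wise, giving one product per row.
row_products = a[:, 0] * a[:, 1] * a[:, 2]
print(row_products.tolist())  # → [6.0, 120.0]
```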

    If you write simple code using NumPy arrays, you can gain performance by simply jitting it, as shown below. If you use lists, you more or less have to rewrite your code.

    import numpy as np
    import numba as nb
    
    @nb.njit(fastmath=True)
    def prod(array):
      assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
      res=np.empty(array.shape[0],dtype=array.dtype)
      for i in range(array.shape[0]):
        res[i]=array[i,0]*array[i,1]*array[i,2]
    
      return res
    

    Using np.prod(a, axis=1) isn't a bad idea, but its performance isn't great here. For an array of only 1000x3, the function call overhead is quite significant. That overhead can be avoided entirely by calling the jitted prod function from another jitted function.
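A minimal sketch of that last point: a second jitted function calls prod directly, so the call between the two stays inside compiled code. The caller sum_of_row_products is a hypothetical example, not part of the original answer, and the ImportError fallback is an assumption added only so the sketch still runs without Numba.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit(fastmath=True)
except ImportError:
    # Assumption for this sketch: without Numba, run as plain (slow) Python.
    def njit(func):
        return func

@njit
def prod(array):
    # Row-wise product for an (n, 3) array, as in the answer above.
    res = np.empty(array.shape[0], dtype=array.dtype)
    for i in range(array.shape[0]):
        res[i] = array[i, 0] * array[i, 1] * array[i, 2]
    return res

@njit
def sum_of_row_products(array):
    # Hypothetical caller: invokes the jitted prod from within jitted code.
    res = prod(array)
    total = 0.0
    for i in range(res.shape[0]):
        total += res[i]
    return total

a = np.arange(1, 13, dtype=np.float64).reshape(4, 3)
print(sum_of_row_products(a))  # 6 + 120 + 504 + 1320
```

The first call still pays the compilation cost, so this only helps when the outer function is called repeatedly or does enough work itself.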

    Benchmarks

    # The first call to the jitted function incurs about 200 ms of compilation overhead.
    # With @nb.njit(fastmath=True, cache=True), the compilation result is cached and reused by subsequent runs.
    n=999
    prod1   = 795  µs
    prod2   = 187  µs
    np.prod = 7.42 µs
    prod    = 0.85 µs
    
    n=9990
    prod1   = 7863 µs
    prod2   = 1810 µs
    np.prod = 50.5 µs
    prod    = 2.96 µs
    