Speeding up element-wise array multiplication in Python

小鲜肉 2020-12-08 11:32

I have been playing around with numba and numexpr, trying to speed up a simple element-wise matrix multiplication. I have not been able to get better results; they both are b

5 answers
  •  挽巷 (original poster)
    2020-12-08 12:09

    How are you doing your timings?

    Creating your random array takes up most of the calculation time, so if you include it in your timing you will hardly see any real difference between the methods. If you create it up front, however, you can actually compare them.
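    A minimal sketch of that timing point, using the standard-library timeit module instead of IPython's %timeit (an assumption, so it runs as a plain script). The array shape is a hypothetical, smaller stand-in for the 10 x 5,000,000 array used below; one measurement builds the input once outside the timed call, the other regenerates it on every call.

```python
import timeit

import numpy as np

# Hypothetical size, smaller than the 10 x 5,000,000 array in the session below,
# so the sketch finishes quickly; scale it up to match your own data.
SHAPE = (10, 500_000)

# Create the input once, outside the timed statement.
a = np.random.rand(*SHAPE)


def best_time(func, repeat=3, number=5):
    """Best-of-`repeat` average seconds per call, similar in spirit to %timeit."""
    times = timeit.repeat(func, repeat=repeat, number=number)
    return min(times) / number


# Times only the multiplication.
multiply_only = best_time(lambda: np.multiply(a, a))


def create_and_multiply():
    # Regenerates the random input on every call, so its cost is timed as well.
    b = np.random.rand(*SHAPE)
    return np.multiply(b, b)


create_plus_multiply = best_time(create_and_multiply)

print(f"multiply only:     {multiply_only * 1e3:.2f} ms")
print(f"create + multiply: {create_plus_multiply * 1e3:.2f} ms")
```

    If the second number is much larger than the first, the random-number generation dominates, and any difference between numpy, numba, and numexpr gets buried in it.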

    Here are my results, and I'm consistently seeing what you are seeing: numpy and numba give about the same results (with numba being a little bit faster).

    (I don't have numexpr available.)

    In [1]: import numpy as np
    In [2]: from numba import autojit
    In [3]: a=np.random.rand(10,5000000)
    
    In [4]: %timeit multiplication1 = np.multiply(a,a)
    10 loops, best of 3: 90 ms per loop
    
    In [5]: # numba
    
    In [6]: def multiplix(X,Y):
       ...:         M = X.shape[0]
       ...:         N = X.shape[1]
       ...:         D = np.empty((M, N), dtype=np.float64)
       ...:         for i in range(M):
       ...:                 for j in range(N):
       ...:                         D[i,j] = X[i, j] * Y[i, j]
       ...:         return D
       ...:         
    
    In [7]: mul = autojit(multiplix)
    
    In [26]: %timeit multiplication1 = np.multiply(a,a)
    10 loops, best of 3: 182 ms per loop
    
    In [27]: %timeit multiplication1 = np.multiply(a,a)
    10 loops, best of 3: 185 ms per loop
    
    In [28]: %timeit multiplication1 = np.multiply(a,a)
    10 loops, best of 3: 181 ms per loop
    
    In [29]: %timeit multiplication2 = mul(a,a)
    10 loops, best of 3: 179 ms per loop
    
    In [30]: %timeit multiplication2 = mul(a,a)
    10 loops, best of 3: 180 ms per loop
    
    In [31]: %timeit multiplication2 = mul(a,a)
    10 loops, best of 3: 178 ms per loop
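    The session above uses numba's old autojit decorator and the np.float alias. Neither is available in recent releases (np.float was removed in NumPy 1.24). Here is a sketch of the same loop kernel, assuming a recent numba where a bare @njit compiles lazily on the first call. The try/except fallback is an assumption added only so the sketch still runs, very slowly, without numba installed.

```python
import numpy as np

try:
    from numba import njit  # recent replacement for the old autojit decorator
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without numba.
    def njit(func):
        return func


@njit
def multiplix(X, Y):
    """Element-wise product of two equally shaped 2-D arrays using explicit loops."""
    M, N = X.shape
    D = np.empty((M, N), dtype=np.float64)  # np.float64 replaces the removed np.float alias
    for i in range(M):
        for j in range(N):
            D[i, j] = X[i, j] * Y[i, j]
    return D  # returned only after both loops have finished


a = np.random.rand(10, 1000)  # hypothetical small input

# The first call triggers JIT compilation; make it once before timing,
# e.g. before running %timeit multiplix(a, a) in IPython.
result = multiplix(a, a)
```

    Calling the function once before timing keeps compilation time out of the measurement. Otherwise the first %timeit loop would include it.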
    

    Update: I used the latest version of numba, compiled from source: '0.11.0-3-gea20d11-dirty'.

    I tested this with the default numpy in Fedora 19 ('1.7.1') and with numpy '1.6.1' compiled from source, linked against:

    Update 3: My earlier results were, of course, incorrect: I had return D in the inner loop, so 90% of the calculations were skipped.

    This provides more evidence for ali_m's assumption that it is really hard to beat numpy's already highly optimized C code.
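    A bug like the misplaced return in Update 3 is easy to catch by comparing a loop kernel against numpy's result before timing it. Below is a small sketch in plain Python (numba is not needed for the check itself). multiplix_buggy is a hypothetical reconstruction of the misplaced return. It uses np.zeros instead of np.empty only so that the skipped cells are deterministic.

```python
import numpy as np


def multiplix_buggy(X, Y):
    """Hypothetical reconstruction of the bug: return placed inside the inner loop."""
    M, N = X.shape
    D = np.zeros((M, N), dtype=np.float64)  # zeros so the skipped cells are predictable
    for i in range(M):
        for j in range(N):
            D[i, j] = X[i, j] * Y[i, j]
            return D  # BUG: exits after computing only the first element
    return D


def multiplix_fixed(X, Y):
    """Correct version: fill every element, then return."""
    M, N = X.shape
    D = np.empty((M, N), dtype=np.float64)
    for i in range(M):
        for j in range(N):
            D[i, j] = X[i, j] * Y[i, j]
    return D  # return only after both loops have finished


a = np.random.rand(10, 50)
expected = np.multiply(a, a)

buggy_ok = np.allclose(multiplix_buggy(a, a), expected)
fixed_ok = np.allclose(multiplix_fixed(a, a), expected)
print(buggy_ok, fixed_ok)  # prints: False True
```

    A kernel that "beats" numpy by a wide margin but fails this check is skipping work, not running faster.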

    However, if you are trying to do something more complicated, e.g.,

    np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))
    

    I can reproduce the figures Jake Vanderplas gets:

    In [14]: %timeit pairwise_numba(X)
    10000 loops, best of 3: 92.6 us per loop
    
    In [15]: %timeit pairwise_numpy(X)
    1000 loops, best of 3: 662 us per loop
    

    So it seems the operation you are doing is already so well optimized in numpy that it is hard to do any better.
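    The answer does not show the pairwise_numpy and pairwise_numba functions timed above. Below is a sketch of what they might look like. The function bodies, the (100, 3) input shape, and the plain-Python fallback when numba is missing are all assumptions, not the code that produced those timings. One plausible reason numba wins here but not for a plain multiply is that the broadcast expression builds an (M, M, N) temporary array, while the loop version accumulates each distance in a scalar.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without numba.
    def njit(func):
        return func


def pairwise_numpy(X):
    """All pairwise Euclidean distances via broadcasting: (M, 1, N) - (M, N) -> (M, M, N)."""
    return np.sqrt(((X[:, None, :] - X) ** 2).sum(-1))


@njit
def pairwise_numba(X):
    """Same distances with explicit loops; no (M, M, N) temporary is allocated."""
    M, N = X.shape
    D = np.empty((M, M), dtype=np.float64)
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = np.sqrt(d)
    return D


X = np.random.random((100, 3))  # hypothetical input: 100 points in 3-D
D_numba = pairwise_numba(X)      # first call also triggers JIT compilation
D_numpy = pairwise_numpy(X)
```

    Both functions return an (M, M) matrix with zeros on the diagonal. Only the loop version avoids the intermediate array, which becomes M times larger than X.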
