Efficient outer product in python

刺人心 2020-11-30 12:41

Computing an outer product in Python seems quite slow when we have to deal with vectors of dimension on the order of 10k. Could someone please give me some ideas on how I could speed up this operation?

3 Answers
  •  陌清茗
    2020-11-30 13:30

    @elyase's answer is great, and rightly accepted. Here's an additional suggestion that, if you can use it, might make the call to np.outer even faster.

    You say "I have to do this operation several times", so you may be able to reuse the array that holds the outer product instead of allocating a new one each time. That can give a nice boost in performance.

    First, some random data to work with:

    In [31]: import numpy as np
    
    In [32]: a = np.random.randn(128)
    
    In [33]: b = np.random.randn(32000)
    

    Here's the baseline timing for np.outer(a, b) on my computer:

    In [34]: %timeit np.outer(a, b)
    100 loops, best of 3: 5.52 ms per loop
    

    Suppose we're going to repeat that operation several times, with arrays of the same shape. Create an out array to hold the result:

    In [35]: out = np.empty((128, 32000))
    

    Now use out as the third argument of np.outer:

    In [36]: %timeit np.outer(a, b, out)
    100 loops, best of 3: 2.38 ms per loop
    

    So you get a nice performance boost, about 2.3x in this benchmark (5.52 ms vs. 2.38 ms), if you can reuse the array that holds the outer product.
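A self-contained sketch of the same idea outside IPython, assuming the inputs keep the same shapes on every iteration. The loop count and the random inputs are made up for illustration; the only point is that the buffer is allocated once and `np.outer` writes into it each time.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 128, 32000

# Allocate the result buffer once, outside the loop.
out = np.empty((m, n))

for _ in range(5):
    # Illustrative inputs; in real code these would come from your own computation.
    a = rng.standard_normal(m)
    b = rng.standard_normal(n)
    # Passing out as the third argument makes np.outer write into the existing
    # buffer instead of allocating a new 128 x 32000 array.
    result = np.outer(a, b, out)
    # The returned array is the very buffer that was passed in.
    assert result is out
```

Note that each iteration overwrites `out`, so copy it first if you need to keep an earlier result.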

    You get a similar benefit from the out argument of np.einsum, and in the Cython function if you pass the output array as a third argument instead of allocating it inside the function with np.empty. (The other compiled/JIT-compiled versions in @elyase's answer will probably benefit from this too, but I only tried the Cython version.)
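For reference, here is a sketch of the in-place variants in plain NumPy. `np.einsum` and the ufunc method `np.multiply.outer` both accept an `out` argument. `outer_into` is a hypothetical helper (not taken from @elyase's answer) that mimics the suggested Cython change by taking the output buffer as an optional third parameter.

```python
import numpy as np

def outer_into(a, b, out=None):
    """Hypothetical helper: outer product of 1-D a and b, optionally into a preallocated out."""
    if out is None:
        # Fall back to allocating, like calling np.empty inside the function.
        out = np.empty((a.shape[0], b.shape[0]), dtype=np.result_type(a, b))
    # Broadcasting a column against a row fills out[i, j] = a[i] * b[j] in place.
    np.multiply(a[:, np.newaxis], b[np.newaxis, :], out=out)
    return out

a = np.random.randn(128)
b = np.random.randn(32000)
out = np.empty((128, 32000))

# einsum with an explicit output array.
np.einsum('i,j->ij', a, b, out=out)

# The ufunc form of the outer product also takes out=.
np.multiply.outer(a, b, out=out)

# The hypothetical helper, reusing the same buffer.
outer_into(a, b, out)
```

All three calls compute the same values; which one is fastest on your machine is worth measuring rather than assuming.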

    Nota bene! The benefit shown above might not be fully realized in practice. The out array fits in the L3 cache of my CPU, and when it is reused in the loop performed by the timeit command, it likely stays in the cache. In practice, the array might be evicted from the cache between calls to np.outer. In that case the improvement isn't so dramatic, but you should still save at least the cost of a call to np.empty(), which is:

    In [53]: %timeit np.empty((128, 32000))
    1000 loops, best of 3: 1.29 ms per loop
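To check how much of the gain survives on your own hardware, a rough script like the following compares allocating a new result, reusing a buffer, and allocating alone. The array sizes match the session above; the repeat counts and function names are arbitrary choices for illustration, and the numbers will vary with CPU, cache size and NumPy version.

```python
import timeit
import numpy as np

a = np.random.randn(128)
b = np.random.randn(32000)
out = np.empty((128, 32000))

def allocate_each_time():
    # A new 128 x 32000 result array is allocated on every call.
    return np.outer(a, b)

def reuse_buffer():
    # The result is written into the preallocated out array.
    return np.outer(a, b, out)

def allocate_only():
    # Cost of the allocation alone, comparable to the np.empty timing above.
    return np.empty((128, 32000))

for fn in (allocate_each_time, reuse_buffer, allocate_only):
    # Best of 3 repeats of 20 calls each, reported per call in milliseconds.
    best = min(timeit.repeat(fn, number=20, repeat=3)) / 20
    print(f"{fn.__name__}: {best * 1e3:.2f} ms per call")
```

Because this still calls the same functions back to back, it shares the warm-cache caveat; interleaving other memory-heavy work between calls gives a more realistic picture.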
    
