Why is numpy's einsum slower than numpy's built-in functions?

情深已故 2020-12-08 22:41

I've usually gotten good performance out of numpy's einsum function (and I like its syntax). @Ophion's answer to this question shows that - for the cases tested - einsu

2 answers
  •  情深已故
    2020-12-08 23:30

    einsum has a specialized case for '2 operands, ndim=2'. In this case there are 3 operands, and a total of 3 dimensions. So it has to use a general nditer.

    While trying to understand how the string input is parsed, I wrote a pure Python einsum simulator, https://github.com/hpaulj/numpy-einsum/blob/master/einsum_py.py

    The (stripped down) einsum and sum-of-products functions are:

    def myeinsum(subscripts, *ops, **kwargs):
        # drop-in replacement for np.einsum (more or less)
        # ... (parsing of subscripts into op_axes omitted in this stripped-down version)
        
        x = sum_of_prod(ops, op_axes, **kwargs)
        return x
    
    def sum_of_prod(ops, op_axes,...):
        ...
        it = np.nditer(ops, flags, op_flags, op_axes)
        it.operands[nop][...] = 0
        it.reset()
        for (x,y,z,w) in it:
            w[...] += x*y*z
        return it.operands[nop]
    

    Debugging output for myeinsum('ik,km,im->i',X,C,X,debug=True) with (M,K)=(10,5)

    {'max_label': 109, 
     'min_label': 105, 
     'nop': 3, 
     'shapes': [(10, 5), (5, 5), (10, 5)], 
     ....}}
     ...
    iter labels: [105, 107, 109],'ikm'
    
    op_axes [[0, 1, -1], [-1, 0, 1], [0, -1, 1], [0, -1, -1]]
    

    If you write a sum-of-prod function like this in cython you should get something close to the generalized einsum.

    With the full (M,K), this simulated einsum is 6-7x slower.
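
To make the stripped-down loop above concrete, here is a minimal runnable sketch of the same sum-of-products for the single case 'ik,km,im->i', using the op_axes from the debugging output. The function name sum_of_prod_nditer, the flag choices, and the small random test arrays are my assumptions for illustration; this is not the code from the linked einsum_py.py, and it is not how numpy's C implementation is written.

```python
import numpy as np

def sum_of_prod_nditer(X, C, Y):
    """Compute einsum('ik,km,im->i', X, C, Y) with an explicit nditer loop.

    Hypothetical helper for illustration only.
    """
    # Iteration axes are (i, k, m); -1 marks an axis that an operand lacks.
    op_axes = [[0, 1, -1],    # X: 'ik'
               [-1, 0, 1],    # C: 'km'
               [0, -1, 1],    # Y: 'im'
               [0, -1, -1]]   # output: 'i' (k and m are summed away)
    it = np.nditer(
        [X, C, Y, None],      # None asks nditer to allocate the output
        flags=['reduce_ok'],  # the output is visited repeatedly (a reduction)
        op_flags=[['readonly']] * 3 + [['readwrite', 'allocate']],
        op_axes=op_axes,
    )
    with it:
        it.operands[3][...] = 0   # zero the freshly allocated output
        it.reset()
        for x, c, y, out in it:
            out[...] += x * c * y  # one scalar product per (i, k, m) triple
        result = it.operands[3].copy()
    return result

# Small assumed shapes, matching the (M, K) = (10, 5) debugging run.
rng = np.random.default_rng(0)
X = rng.standard_normal((10, 5))
C = rng.standard_normal((5, 5))
res = sum_of_prod_nditer(X, C, X)
ref = np.einsum('ik,km,im->i', X, C, X)
```

The Python-level loop visits every (i, k, m) combination one element at a time, so it is only practical for small arrays; its purpose is to show what the general iterator has to do, not to be fast.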


    Some timings building on the other answers:

    In [84]: timeit np.dot(X,C)
    1 loops, best of 3: 781 ms per loop
    
    In [85]: timeit np.einsum('ik,km->im',X,C)
    1 loops, best of 3: 1.28 s per loop
    
    In [86]: timeit np.einsum('im,im->i',A,X)
    10 loops, best of 3: 163 ms per loop
    

    This 'im,im->i' step is substantially faster than the other. The sum dimension, m, is only 20. I suspect einsum is treating this as a special case.

    In [87]: timeit np.einsum('im,im->i',np.dot(X,C),X)
    1 loops, best of 3: 950 ms per loop
    
    In [88]: timeit np.einsum('im,im->i',np.einsum('ik,km->im',X,C),X)
    1 loops, best of 3: 1.45 s per loop
    

    The times for these composite calculations are simply sums of the corresponding pieces.
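
As a sanity check on the decomposition timed above, the following sketch confirms that the two-step forms give the same result as the single three-operand call. The sizes M and K and the random data are assumptions chosen to keep the check quick; they are not the sizes used for the timings.

```python
import numpy as np

# Assumed small sizes; the timings above used much larger arrays.
M, K = 50, 20
rng = np.random.default_rng(0)
X = rng.standard_normal((M, K))
C = rng.standard_normal((K, K))

# Single three-operand call (goes through the general iterator).
direct = np.einsum('ik,km,im->i', X, C, X)

# Two-step versions: matrix product first, then a row-wise dot with X.
via_dot = np.einsum('im,im->i', np.dot(X, C), X)
via_einsum = np.einsum('im,im->i', np.einsum('ik,km->im', X, C), X)
```

Splitting the computation this way materializes the (M, K) intermediate product, which is the memory cost paid for letting each step use a two-operand code path.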
