Understanding NumPy's einsum

死守一世寂寞 2020-11-22 14:36

I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it's not seeming to stick.

Here's an

6 answers
  •  无人共我
    2020-11-22 15:14

    I found "NumPy: The tricks of the trade (Part II)" instructive.

    We use -> to indicate the order of the output array. So think of 'ij, i->j' as having a left-hand side (LHS) and a right-hand side (RHS). When a label is repeated across the operands on the LHS, the matching elements are multiplied together element-wise. Any label that does not appear on the RHS is then summed over. By choosing which labels appear on the RHS (the output), we decide which axes of the input survive and which are summed away, i.e. summation along axis 0, axis 1, and so on.
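
    Here is a small check of the rule above. The arrays x and v are hypothetical inputs chosen only for this illustration. The sketch compares einsum against the equivalent broadcast-and-sum, and then shows how the output label picks the axis that gets summed:

```python
import numpy as np

# Hypothetical inputs for illustration (not from the original answer).
x = np.arange(6).reshape(2, 3)   # labels i, j -> shape (2, 3)
v = np.array([10, 20])           # label i     -> shape (2,)

# 'ij,i->j': i is repeated on the LHS, so x[i, j] * v[i] is formed element-wise;
# i is missing from the RHS, so it is summed over, leaving an array indexed by j.
out_j = np.einsum('ij,i->j', x, v)
manual_j = (x * v[:, np.newaxis]).sum(axis=0)
print(out_j.tolist())       # [60, 90, 120]

# Changing the RHS label changes which axis survives:
col_sums = np.einsum('ij->j', x)   # sums over i (axis 0)
row_sums = np.einsum('ij->i', x)   # sums over j (axis 1)
print(col_sums.tolist())    # [3, 5, 7]
print(row_sums.tolist())    # [3, 12]
```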

    >>> import numpy as np
    >>> a = np.array([[1, 1, 1],
    ...               [2, 2, 2],
    ...               [3, 3, 3]])
    >>> b = np.arange(9).reshape(3, 3)
    >>> a
    array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])
    >>> b
    array([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]])
    >>> d = np.einsum('ij, jk->ki', a, b)
    

    Notice that there are three labels, i, j and k, and that j is repeated on the left-hand side. i and j label the rows and columns of a; j and k label the rows and columns of b.
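
    Because a and b are both 3x3, this example hides which label maps to which length. The sketch below uses hypothetical non-square arrays p and q, which are assumptions for illustration only, so that i, j and k have different sizes. It then checks that the output shape follows the right-hand side:

```python
import numpy as np

# Hypothetical non-square inputs so that i, j and k have different lengths.
p = np.ones((2, 3))   # labels i, j -> i has length 2, j has length 3
q = np.ones((3, 4))   # labels j, k -> j has length 3, k has length 4

out = np.einsum('ij,jk->ki', p, q)
print(out.shape)      # (4, 2): output axes follow the RHS order k, i

# The letters themselves are arbitrary; only which labels repeat and
# which appear after '->' matter.
assert np.array_equal(out, np.einsum('xy,yz->zx', p, q))

# The shared label j must have the same length in both operands.
try:
    np.einsum('ij,jk->ki', p, np.ones((5, 4)))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```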

    To compute the product with the j axes aligned, we add a trailing axis to a. b is then broadcast along the first axis (i), because broadcasting aligns shapes from the right: (3, 3, 1) against (3, 3).

    a[i, j, k]
       b[j, k]
    
    >>> c = a[:,:,np.newaxis] * b
    >>> c
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],
    
           [[ 0,  2,  4],
            [ 6,  8, 10],
            [12, 14, 16]],
    
           [[ 0,  3,  6],
            [ 9, 12, 15],
            [18, 21, 24]]])
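
    Here is a sketch of the same broadcasting step with the shapes printed, assuming the a and b defined earlier. As an aside, it also checks that keeping every label on the right-hand side ('ij,jk->ijk') makes einsum skip the summation and return this intermediate array:

```python
import numpy as np

a = np.array([[1, 1, 1],
              [2, 2, 2],
              [3, 3, 3]])
b = np.arange(9).reshape(3, 3)

a3 = a[:, :, np.newaxis]          # shape (3, 3, 1): axes i, j, (k)
print(a3.shape, b.shape)          # (3, 3, 1) (3, 3)

# Broadcasting right-aligns the shapes, so b (axes j, k) is repeated along i.
c = a3 * b                        # shape (3, 3, 3), c[i, j, k] = a[i, j] * b[j, k]

# With no label dropped on the RHS, einsum sums nothing.
c_einsum = np.einsum('ij,jk->ijk', a, b)
assert np.array_equal(c, c_einsum)

# Spot-check the element-wise definition.
i, j, k = 2, 1, 0
assert c[i, j, k] == a[i, j] * b[j, k]
```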
    

    j is absent from the right-hand side, so we sum over j, which is the second axis (axis 1) of the 3x3x3 array.

    >>> c = c.sum(1)
    >>> c
    array([[ 9, 12, 15],
           [18, 24, 30],
           [27, 36, 45]])
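
    Since j has been summed away and the result is indexed by i and k, this intermediate is just the matrix product of a and b. Here is a quick check, using the same definitions of a and b:

```python
import numpy as np

a = np.array([[1, 1, 1],
              [2, 2, 2],
              [3, 3, 3]])
b = np.arange(9).reshape(3, 3)
c = a[:, :, np.newaxis] * b       # c[i, j, k] = a[i, j] * b[j, k]

# Summing out axis 1 (label j) leaves an array indexed by (i, k).
summed = c.sum(axis=1)

# The same contraction written two other ways.
assert np.array_equal(summed, np.einsum('ij,jk->ik', a, b))
assert np.array_equal(summed, a @ b)
print(summed.tolist())   # [[9, 12, 15], [18, 24, 30], [27, 36, 45]]
```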
    

    Finally, the right-hand side asks for the labels in the order ki, the reverse of the ik order we have so far, so we transpose.

    >>> c.T
    array([[ 9, 18, 27],
           [12, 24, 36],
           [15, 30, 45]])
    
    >>> np.einsum('ij, jk->ki', a, b)
    array([[ 9, 18, 27],
           [12, 24, 36],
           [15, 30, 45]])
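
    To see every step at once, here is a plain-loop sketch of 'ij, jk->ki'. The helper einsum_ij_jk_to_ki is hypothetical and written only for illustration. It is not how NumPy implements einsum internally:

```python
import numpy as np

def einsum_ij_jk_to_ki(a, b):
    """Hypothetical loop-based equivalent of np.einsum('ij,jk->ki', a, b)."""
    I, J = a.shape
    J2, K = b.shape
    if J != J2:
        raise ValueError("label j must have the same length in both inputs")
    out = np.zeros((K, I), dtype=np.result_type(a, b))
    for i in range(I):
        for k in range(K):
            total = 0
            for j in range(J):          # j is summed: it is not on the RHS
                total += a[i, j] * b[j, k]
            out[k, i] = total           # stored in RHS order: k, then i
    return out

a = np.array([[1, 1, 1],
              [2, 2, 2],
              [3, 3, 3]])
b = np.arange(9).reshape(3, 3)

result = einsum_ij_jk_to_ki(a, b)
assert np.array_equal(result, np.einsum('ij, jk->ki', a, b))
print(result.tolist())   # [[9, 18, 27], [12, 24, 36], [15, 30, 45]]
```

    The innermost loop over j is the summation, and writing out[k, i] instead of out[i, k] is the transpose.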
    
