Understanding NumPy's einsum

死守一世寂寞 2020-11-22 14:36

I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it doesn't seem to stick.

Here's an

6 Answers
  •  遥遥无期
    2020-11-22 15:20

    Let's make two arrays with different but compatible dimensions to highlight how they interact:

    In [42]: import numpy as np

    In [43]: A = np.arange(6).reshape(2,3); A
    Out[43]: 
    array([[0, 1, 2],
           [3, 4, 5]])

    In [44]: B = np.arange(12).reshape(3,4); B
    Out[44]: 
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]])
    

    Your calculation takes a 'dot' (sum of products) of a (2,3) array with a (3,4) array to produce a (4,2) array. i is the first dimension of A and the last of C; k is the last dimension of B and the first of C. j is 'consumed' by the summation.
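    Written out as explicit loops, the same index bookkeeping becomes concrete. This is a minimal sketch: the helper name einsum_ij_jk_ki is hypothetical, and NumPy does not evaluate einsum this way internally; the loops only reproduce the same arithmetic for these particular subscripts.

```python
import numpy as np

def einsum_ij_jk_ki(A, B):
    """Loop version of np.einsum('ij,jk->ki', A, B) (hypothetical helper)."""
    I, J = A.shape
    J2, K = B.shape
    assert J == J2, "the shared index j must have the same length in A and B"
    # The output subscripts are 'ki', so the result has shape (K, I).
    C = np.zeros((K, I), dtype=np.result_type(A, B))
    for i in range(I):
        for k in range(K):
            # j appears in the inputs but not in the output, so it is summed over.
            total = 0
            for j in range(J):
                total += A[i, j] * B[j, k]
            C[k, i] = total
    return C

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)
C_loop = einsum_ij_jk_ki(A, B)
print(C_loop.tolist())                                     # [[20, 56], [23, 68], [26, 80], [29, 92]]
print(np.array_equal(C_loop, np.einsum('ij,jk->ki', A, B)))  # True
```

    Each output index becomes an outer loop, and each index that is missing from the output becomes an inner summation loop.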

    In [45]: C = np.einsum('ij,jk->ki', A, B); C
    Out[45]: 
    array([[20, 56],
           [23, 68],
           [26, 80],
           [29, 92]])
    

    This is the same as np.dot(A,B).T: it is the final output that is transposed, not either of the inputs.
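    A quick check of that equivalence, using the same A and B. By the identity (AB)^T = B^T A^T, the result can also be read as B.T @ A.T, and keeping the output subscripts in 'ik' order gives the plain, untransposed product.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)

C = np.einsum('ij,jk->ki', A, B)

# The transpose applies to the (2,4) product, giving a (4,2) result.
print(np.array_equal(C, np.dot(A, B).T))                     # True
# Same array via (AB)^T = B^T A^T.
print(np.array_equal(C, B.T @ A.T))                          # True
# Output order 'ik' instead of 'ki' gives the untransposed product.
print(np.array_equal(np.einsum('ij,jk->ik', A, B), A @ B))   # True
print(C.shape)                                               # (4, 2)
```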

    To see more of what happens to j, change the output subscripts to ijk:

    In [46]: np.einsum('ij,jk->ijk',A,B)
    Out[46]: 
    array([[[ 0,  0,  0,  0],
            [ 4,  5,  6,  7],
            [16, 18, 20, 22]],
    
           [[ 0,  3,  6,  9],
            [16, 20, 24, 28],
            [40, 45, 50, 55]]])
    

    This can also be produced with:

    A[:,:,None]*B[None,:,:]
    

    That is, add a k dimension to the end of A and an i dimension to the front of B; broadcasting the product then gives a (2,3,4) array.
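    A short sketch confirming that the 'ij,jk->ijk' call and the broadcast product build the same (2,3,4) array, where every entry is A[i,j] * B[j,k] and nothing has been summed yet:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)

# einsum with every index kept: P[i, j, k] = A[i, j] * B[j, k], no summation.
P_einsum = np.einsum('ij,jk->ijk', A, B)

# Broadcasting: A is viewed as (2, 3, 1) and B as (1, 3, 4); the product is (2, 3, 4).
P_broadcast = A[:, :, None] * B[None, :, :]

print(P_einsum.shape)                          # (2, 3, 4)
print(np.array_equal(P_einsum, P_broadcast))   # True
# Spot-check one entry against the definition: 5 * 11 = 55.
print(P_einsum[1, 2, 3] == A[1, 2] * B[2, 3])  # True
```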

    0 + 4 + 16 = 20, 9 + 28 + 55 = 92, etc. Summing over j and transposing gives the earlier result:

    np.sum(A[:,:,None] * B[None,:,:], axis=1).T
    
    # C[k,i] = sum_j A[i,j,(k)] * B[(i),j,k]   (parenthesized indices are the added broadcast axes)
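    The two steps can also be split across two einsum calls, which makes the roles of the subscripts explicit: dropping j from the output means "sum over j", and writing ki rather than ik performs the transpose. A minimal sketch with the same A and B:

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)

P = np.einsum('ij,jk->ijk', A, B)   # (2, 3, 4) intermediate, nothing summed
C_from_P = np.einsum('ijk->ki', P)  # sum over j, output ordered (k, i)

print(np.array_equal(C_from_P, np.sum(P, axis=1).T))  # True
# The individual terms behind two of the sums quoted above.
print(P[0, :, 0].tolist(), P[0, :, 0].sum())          # [0, 4, 16] 20
print(P[1, :, 3].tolist(), P[1, :, 3].sum())          # [9, 28, 55] 92
```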
    
