Understanding NumPy's einsum

死守一世寂寞 2020-11-22 14:36

I'm struggling to understand exactly how einsum works. I've looked at the documentation and a few examples, but it's not seeming to stick.

Here's an …

6 Answers
  •  旧时难觅i
    2020-11-22 15:24

    Grasping the idea of numpy.einsum() is easy once you build an intuition for it. As an example, let's start with a simple description involving matrix multiplication.


    To use numpy.einsum(), all you have to do is pass the so-called subscripts string as an argument, followed by your input arrays.

    Let's say you have two 2D arrays, A and B, and you want to do matrix multiplication. So, you do:

    np.einsum("ij, jk -> ik", A, B)
    

    Here the subscript string ij corresponds to array A while the subscript string jk corresponds to array B. Also, the most important thing to note is that the number of characters in each subscript string must match the number of dimensions of the array (i.e. two chars for 2D arrays, three chars for 3D arrays, and so on). And if you repeat a char between subscript strings (j in our case), that means you want the multiplication to happen along those dimensions; if the repeated char is also omitted from the output subscript, that dimension is sum-reduced (i.e. it will be gone from the result).

    The subscript string after the -> describes the resultant array. If you leave it empty, everything will be summed and a scalar value is returned as the result. Otherwise, the resultant array will have dimensions according to the subscript string. In our example, it'll be ik. This is intuitive because we know that for matrix multiplication the number of columns in array A has to match the number of rows in array B, which is what is happening here (i.e. we encode this knowledge by repeating the char j in the subscript string).
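
    To make this concrete, here is what "ij, jk -> ik" computes, written out as explicit loops (a sketch for intuition only; np.einsum itself computes the same thing in optimized C). The free indices i and k survive in the output, while the repeated index j is multiplied over and summed away:

    import numpy as np
    
    # small toy inputs (chosen just for this illustration)
    A = np.arange(6).reshape(2, 3)
    B = np.arange(12).reshape(3, 4)
    
    out = np.zeros((2, 4), dtype=A.dtype)
    for i in range(2):           # free index i: kept in the output
        for k in range(4):       # free index k: kept in the output
            for j in range(3):   # repeated index j: summed away
                out[i, k] += A[i, j] * B[j, k]
    
    assert np.allclose(out, np.einsum("ij, jk -> ik", A, B))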


    Here are some more examples illustrating the use/power of np.einsum() in implementing some common tensor or nd-array operations, succinctly.

    Inputs

    # a vector
    In [197]: vec
    Out[197]: array([0, 1, 2, 3])
    
    # an array
    In [198]: A
    Out[198]: 
    array([[11, 12, 13, 14],
           [21, 22, 23, 24],
           [31, 32, 33, 34],
           [41, 42, 43, 44]])
    
    # another array
    In [199]: B
    Out[199]: 
    array([[1, 1, 1, 1],
           [2, 2, 2, 2],
           [3, 3, 3, 3],
           [4, 4, 4, 4]])
    

    1) Matrix multiplication (similar to np.matmul(arr1, arr2))

    In [200]: np.einsum("ij, jk -> ik", A, B)
    Out[200]: 
    array([[130, 130, 130, 130],
           [230, 230, 230, 230],
           [330, 330, 330, 330],
           [430, 430, 430, 430]])
    

    2) Extract elements along the main diagonal (similar to np.diag(arr))

    In [202]: np.einsum("ii -> i", A)
    Out[202]: array([11, 22, 33, 44])
    

    3) Hadamard product (i.e. element-wise product of two arrays) (similar to arr1 * arr2)

    In [203]: np.einsum("ij, ij -> ij", A, B)
    Out[203]: 
    array([[ 11,  12,  13,  14],
           [ 42,  44,  46,  48],
           [ 93,  96,  99, 102],
           [164, 168, 172, 176]])
    

    4) Element-wise squaring (similar to np.square(arr) or arr ** 2)

    In [210]: np.einsum("ij, ij -> ij", B, B)
    Out[210]: 
    array([[ 1,  1,  1,  1],
           [ 4,  4,  4,  4],
           [ 9,  9,  9,  9],
           [16, 16, 16, 16]])
    

    5) Trace (i.e. sum of main diagonal elements) (similar to np.trace(arr))

    In [217]: np.einsum("ii -> ", A)
    Out[217]: 110
    

    6) Matrix transpose (similar to np.transpose(arr))

    In [221]: np.einsum("ij -> ji", A)
    Out[221]: 
    array([[11, 21, 31, 41],
           [12, 22, 32, 42],
           [13, 23, 33, 43],
           [14, 24, 34, 44]])
    

    7) Outer product (of vectors) (similar to np.outer(vec1, vec2))

    In [255]: np.einsum("i, j -> ij", vec, vec)
    Out[255]: 
    array([[0, 0, 0, 0],
           [0, 1, 2, 3],
           [0, 2, 4, 6],
           [0, 3, 6, 9]])
    

    8) Inner product (of vectors) (similar to np.inner(vec1, vec2))

    In [256]: np.einsum("i, i -> ", vec, vec)
    Out[256]: 14
    

    9) Sum along axis 0 (similar to np.sum(arr, axis=0))

    In [260]: np.einsum("ij -> j", B)
    Out[260]: array([10, 10, 10, 10])
    

    10) Sum along axis 1 (similar to np.sum(arr, axis=1))

    In [261]: np.einsum("ij -> i", B)
    Out[261]: array([ 4,  8, 12, 16])
    

    11) Batch matrix multiplication

    In [287]: BM = np.stack((A, B), axis=0)
    
    In [288]: BM
    Out[288]: 
    array([[[11, 12, 13, 14],
            [21, 22, 23, 24],
            [31, 32, 33, 34],
            [41, 42, 43, 44]],
    
           [[ 1,  1,  1,  1],
            [ 2,  2,  2,  2],
            [ 3,  3,  3,  3],
            [ 4,  4,  4,  4]]])
    
    In [289]: BM.shape
    Out[289]: (2, 4, 4)
    
    # batch matrix multiply using einsum
    In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)
    
    In [293]: BMM
    Out[293]: 
    array([[[1350, 1400, 1450, 1500],
            [2390, 2480, 2570, 2660],
            [3430, 3560, 3690, 3820],
            [4470, 4640, 4810, 4980]],
    
           [[  10,   10,   10,   10],
            [  20,   20,   20,   20],
            [  30,   30,   30,   30],
            [  40,   40,   40,   40]]])
    
    In [294]: BMM.shape
    Out[294]: (2, 4, 4)
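
    Since np.matmul() also broadcasts over the leading batch dimension, the batched einsum result can be cross-checked against it directly:
    
    In [295]: np.allclose(BMM, np.matmul(BM, BM))
    Out[295]: True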
    

    12) Sum along axis 2 (similar to np.sum(arr, axis=2))

    In [330]: np.einsum("ijk -> ij", BM)
    Out[330]: 
    array([[ 50,  90, 130, 170],
           [  4,   8,  12,  16]])
    

    13) Sum all the elements in array (similar to np.sum(arr))

    In [335]: np.einsum("ijk -> ", BM)
    Out[335]: 480
    

    14) Sum over multiple axes (i.e. marginalization)
    (similar to np.sum(arr, axis=(0, 1, 2, 3, 4, 6, 7)))

    # 8D array
    In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))
    
    # keep only axis 5 (i.e. "n" here); all other axes are marginalized out
    In [363]: esum = np.einsum("ijklmnop -> n", R)
    
    # the same thing with np.sum, summing over the rest of the axes
    In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))
    
    In [365]: np.allclose(esum, nsum)
    Out[365]: True
    

    15) Double dot product (similar to np.sum(arr1 * arr2), i.e. the sum of the Hadamard product from 3)

    In [772]: A
    Out[772]: 
    array([[1, 2, 3],
           [4, 2, 2],
           [2, 3, 4]])
    
    In [773]: B
    Out[773]: 
    array([[1, 4, 7],
           [2, 5, 8],
           [3, 6, 9]])
    
    In [774]: np.einsum("ij, ij -> ", A, B)
    Out[774]: 124
    

    16) 2D and 3D array multiplication

    Such a multiplication can be very useful when solving a linear system of equations (Ax = b) where you want to verify the result.

    # inputs
    In [115]: A = np.random.rand(3,3)
    In [116]: b = np.random.rand(3, 4, 5)
    
    # solve for x
    In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)
    
    # 2D and 3D array multiplication; with no "->" given, einsum
    # infers the output subscript from the free indices (here: ikl)
    In [118]: Ax = np.einsum('ij, jkl', A, x)
    
    # indeed the same!
    In [119]: np.allclose(Ax, b)
    Out[119]: True
    

    By contrast, if one had to use np.matmul() for this verification, it would take a couple of reshape operations to achieve the same result:

    # reshape 3D array `x` to 2D, perform matmul
    # then reshape the resultant array to 3D
    In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)
    
    # indeed correct!
    In [124]: np.allclose(Ax, Ax_matmul)
    Out[124]: True
    

    Bonus: Read more on the math here: Einstein-Summation, and definitely here: Tensor-Notation.
