Use numpy.tensordot to replace a nested loop

爱一瞬间的悲伤 2020-12-03 23:26

I have a piece of code whose performance I want to improve. My code is:

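lis = np.zeros((6, 6), int)
for i in range(6):
    for j in range(6):
        for k in range(6):
            for l in range(6):
                lis[i,j] += matrix1[k,l] * (2 * matrix2[i,j,k,l] - matrix2[i,k,j,l])
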
2 Answers
  •  孤城傲影
    2020-12-04 00:24

    Test setup:

    In [274]: lis = np.zeros((6,6),int)
    In [275]: matrix1 = np.arange(36).reshape(6,6)
    In [276]: matrix2 = np.arange(36*36).reshape(6,6,6,6)
    In [277]: for i in range(6):
         ...:     for j in range(6):
         ...:         for k in range(6):
         ...:             for l in range(6):
         ...:                 lis[i,j] += matrix1[k,l] * (2 * matrix2[i,j,k,l] - matrix2[i,k,j,l])
         ...:
    In [278]: lis
    Out[278]: 
    array([[-51240,  -9660,  31920,  73500, 115080, 156660],
           [ 84840, 126420, 168000, 209580, 251160, 292740],
           [220920, 262500, 304080, 345660, 387240, 428820],
           [357000, 398580, 440160, 481740, 523320, 564900],
           [493080, 534660, 576240, 617820, 659400, 700980],
           [629160, 670740, 712320, 753900, 795480, 837060]])
    

    Is that right?

    I'm not sure that tensordot is the right tool; at the least it may not be the simplest. It certainly can't handle the matrix2 difference directly.
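
    Because the expression is linear in matrix2, though, the difference can be handled with two tensordot calls whose results are combined afterwards. A minimal sketch using the same test arrays: in the second term the loop indexes matrix2[i,k,j,l], so matrix1's axes pair with axes 1 and 3 of matrix2, leaving axes 0 and 2 as i and j.

```python
import numpy as np

matrix1 = np.arange(36).reshape(6, 6)
matrix2 = np.arange(36 * 36).reshape(6, 6, 6, 6)

# First term: sum over k,l of matrix1[k,l] * matrix2[i,j,k,l]
# (matrix1 axes 0,1 contracted with matrix2 axes 2,3; leftover axes 0,1 give i,j)
term1 = np.tensordot(matrix1, matrix2, axes=([0, 1], [2, 3]))

# Second term: sum over k,l of matrix1[k,l] * matrix2[i,k,j,l]
# (k sits on matrix2 axis 1 and l on axis 3; leftover axes 0,2 give i,j)
term2 = np.tensordot(matrix1, matrix2, axes=([0, 1], [1, 3]))

lis = 2 * term1 - term2
print(lis[0])
```

    The printed row should match the first row of Out[278].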

    Let's start with an obvious substitution:

    In [279]: matrix3 = 2*matrix2-matrix2.transpose(0,2,1,3)
    In [280]: lis = np.zeros((6,6),int)
    In [281]: for i in range(6):
         ...:     for j in range(6):
         ...:         for k in range(6):
         ...:             for l in range(6):
         ...:                 lis[i,j] += matrix1[k,l] * matrix3[i,j,k,l]
    

    This tests OK: it produces the same lis.
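
    To see why transpose(0,2,1,3) is the right substitution, a quick sketch that checks it against the loop's explicit index swap (a brute-force comparison over every index combination, so only sensible for small arrays like this one):

```python
import itertools
import numpy as np

matrix2 = np.arange(36 * 36).reshape(6, 6, 6, 6)
matrix3 = 2 * matrix2 - matrix2.transpose(0, 2, 1, 3)

# transpose(0, 2, 1, 3) swaps axes 1 and 2, the same as np.swapaxes(matrix2, 1, 2)
same_as_swapaxes = np.array_equal(
    matrix2.transpose(0, 2, 1, 3), np.swapaxes(matrix2, 1, 2)
)

# Element [i, j, k, l] of the transposed array is matrix2[i, k, j, l],
# so matrix3 reproduces the loop's expression at every index
matches_loop = all(
    matrix3[i, j, k, l] == 2 * matrix2[i, j, k, l] - matrix2[i, k, j, l]
    for i, j, k, l in itertools.product(range(6), repeat=4)
)
print(same_as_swapaxes, matches_loop)
```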

    Now it is easy to express this with einsum: just copy the indices from the loop into the subscript string.

    In [284]: np.einsum('kl,ijkl->ij', matrix1, matrix3)
    Out[284]: 
    array([[-51240,  -9660,  31920,  73500, 115080, 156660],
           [ 84840, 126420, 168000, 209580, 251160, 292740],
           [220920, 262500, 304080, 345660, 387240, 428820],
           [357000, 398580, 440160, 481740, 523320, 564900],
           [493080, 534660, 576240, 617820, 659400, 700980],
           [629160, 670740, 712320, 753900, 795480, 837060]])
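
    einsum can also take the loop's original indexing as written, which makes building matrix3 optional. A sketch under the same test setup, where matrix2[i,k,j,l] in the loop becomes the subscript 'ikjl':

```python
import numpy as np

matrix1 = np.arange(36).reshape(6, 6)
matrix2 = np.arange(36 * 36).reshape(6, 6, 6, 6)

# matrix2[i,j,k,l] in the loop -> subscripts 'ijkl'
first = np.einsum('kl,ijkl->ij', matrix1, matrix2)
# matrix2[i,k,j,l] in the loop -> subscripts 'ikjl' (no transpose needed)
second = np.einsum('kl,ikjl->ij', matrix1, matrix2)
lis = 2 * first - second

# Same result as contracting against the precomputed matrix3
matrix3 = 2 * matrix2 - matrix2.transpose(0, 2, 1, 3)
print(np.array_equal(lis, np.einsum('kl,ijkl->ij', matrix1, matrix3)))
```

    This trades one extra einsum call for not allocating matrix3; which version is faster depends on the array sizes, so it is worth timing on the real data.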
    

    An elementwise product followed by summation over the last two axes also works, as does an equivalent tensordot (specifying which axes to sum over):

    (matrix1*matrix3).sum(axis=(2,3))
    np.tensordot(matrix1, matrix3, [[0,1],[2,3]])
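
    For intuition about the tensordot call: contracting matrix1 against the last two axes of matrix3 amounts to flattening those two axes into one and doing a matrix-vector product, which is essentially how NumPy's tensordot works internally (reshape, then dot). A sketch that checks this reshape-and-matmul form, together with the expressions above, against the original loop on random integer data; the seed, value range, and size n are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
matrix1 = rng.integers(-10, 10, size=(n, n))
matrix2 = rng.integers(-10, 10, size=(n, n, n, n))
matrix3 = 2 * matrix2 - matrix2.transpose(0, 2, 1, 3)

# Reference: the original quadruple loop
lis = np.zeros((n, n), dtype=matrix1.dtype)
for i in range(n):
    for j in range(n):
        for k in range(n):
            for l in range(n):
                lis[i, j] += matrix1[k, l] * (2 * matrix2[i, j, k, l] - matrix2[i, k, j, l])

# Flatten (k, l) into one axis of length n*n; matmul then treats matrix3 as a
# stack of (n, n*n) matrices multiplied by the flattened matrix1 vector
via_matmul = matrix3.reshape(n, n, n * n) @ matrix1.reshape(n * n)
via_sum = (matrix1 * matrix3).sum(axis=(2, 3))
via_tensordot = np.tensordot(matrix1, matrix3, [[0, 1], [2, 3]])
via_einsum = np.einsum('kl,ijkl->ij', matrix1, matrix3)

all_match = all(
    np.array_equal(lis, r) for r in (via_matmul, via_sum, via_tensordot, via_einsum)
)
print(all_match)
```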
    
