Fast tensor rotation with NumPy

情话喂你 2020-12-04 17:35

At the heart of an application (written in Python and using NumPy), I need to rotate a fourth-order tensor. Actually, I need to rotate a lot of tensors many times, and this is my

7 answers
  •  眼角桃花
    2020-12-04 18:30

    To use tensordot, first build the eighth-order tensor gggg from outer products of the rotation matrix g:

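    import numpy as np

    def rotT(T, g):
        # np.outer flattens its inputs: gg[p, q] = g.flat[p] * g.flat[q]
        gg = np.outer(g, g)
        # Split into 8 axes: gggg[a, i, b, j, c, k, d, l] = g[a, i] g[b, j] g[c, k] g[d, l]
        gggg = np.outer(gg, gg).reshape(4 * g.shape)
        # Contract axes a, b, c, d of gggg with the four axes of T; result is (i, j, k, l)
        axes = ((0, 2, 4, 6), (0, 1, 2, 3))
        return np.tensordot(gggg, T, axes)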
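    As a cross-check, the contraction computed by rotT can also be written with np.einsum and with explicit loops. This is a sketch that assumes T has shape (3, 3, 3, 3) and g is a 3x3 matrix. The names rot_outer, rot_einsum and rot_loops are hypothetical. The index convention (summing over the first index of each g) is read off the tensordot axes above. The test matrix comes from a QR decomposition, so it is orthogonal but may be a reflection. That does not matter for this equivalence check.

```python
import itertools

import numpy as np


def rot_outer(T, g):
    # Outer-product + tensordot version, same as rotT above.
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)
    return np.tensordot(gggg, T, ((0, 2, 4, 6), (0, 1, 2, 3)))


def rot_einsum(T, g):
    # T'[i,j,k,l] = sum over a,b,c,d of g[a,i] g[b,j] g[c,k] g[d,l] T[a,b,c,d]
    return np.einsum("ai,bj,ck,dl,abcd->ijkl", g, g, g, g, T)


def rot_loops(T, g):
    # Direct (slow) evaluation of the same sum, for reference only.
    out = np.zeros((3, 3, 3, 3))
    for i, j, k, l in itertools.product(range(3), repeat=4):
        for a, b, c, d in itertools.product(range(3), repeat=4):
            out[i, j, k, l] += g[a, i] * g[b, j] * g[c, k] * g[d, l] * T[a, b, c, d]
    return out


rng = np.random.default_rng(0)
T = rng.standard_normal((3, 3, 3, 3))
# Orthogonal test matrix from QR (assumed stand-in for a rotation).
g, _ = np.linalg.qr(rng.standard_normal((3, 3)))

print(np.allclose(rot_outer(T, g), rot_einsum(T, g)))  # True
print(np.allclose(rot_outer(T, g), rot_loops(T, g)))   # True
```

    The einsum form states the index pattern most directly. The loop form is only there to make the sum explicit and is far too slow for real use.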
    

    On my system, this is around seven times faster than Sven's solution. If the g tensor doesn't change often, you can also cache the gggg tensor. If you do this and apply some micro-optimizations (inlining the tensordot code, skipping checks, dropping support for generic shapes), you can make it roughly twice as fast again:

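    import numpy as np

    def rotT(T, gggg):
        # gggg is the cached 8-axis tensor; put output axes (i, j, k, l) first,
        # flatten to an 81x81 matrix, and multiply it by T flattened to a column.
        return np.dot(gggg.transpose((1, 3, 5, 7, 0, 2, 4, 6)).reshape((81, 81)),
                      T.reshape(81, 1)).reshape((3, 3, 3, 3))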
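    The question mentions rotating many tensors. One possible extension, which is not part of the original answer, is to build the 81x81 matrix once per g and then rotate a whole stack of tensors with a single matrix product. The function names rotation_matrix_81 and rot_many are hypothetical. The (N, 3, 3, 3, 3) stacking layout is an assumption.

```python
import numpy as np


def rotation_matrix_81(g):
    # R[(i,j,k,l), (a,b,c,d)] = g[a,i] g[b,j] g[c,k] g[d,l], flattened in C order.
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)
    # Output axes first; the transpose/reshape copy happens once here, not per tensor.
    return gggg.transpose((1, 3, 5, 7, 0, 2, 4, 6)).reshape(81, 81)


def rot_many(Ts, R):
    # Ts has shape (N, 3, 3, 3, 3). Each tensor becomes a row of length 81,
    # so all N rotations are one (N, 81) @ (81, 81) product.
    flat = Ts.reshape(-1, 81)
    return (flat @ R.T).reshape(Ts.shape)


# Usage sketch with random data (assumed, for illustration only).
rng = np.random.default_rng(1)
g, _ = np.linalg.qr(rng.standard_normal((3, 3)))
Ts = rng.standard_normal((1000, 3, 3, 3, 3))
R = rotation_matrix_81(g)    # once per g
rotated = rot_many(Ts, R)    # shape (1000, 3, 3, 3, 3)
```

    A single tensor can go through the same path as rot_many(T[None], R)[0]. Batching hands the whole workload to one BLAS-backed matmul instead of N separate Python-level calls. Whether that actually beats the per-tensor version on a given machine is something to measure rather than assume.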
    

    Results of timeit (total seconds for 500 iterations):

                            Home laptop       Work machine
    Your original code      19.471129179      9.77922987938
    Sven's code             0.718412876129    0.137110948563
    My first code           0.118047952652    0.0569641590118
    My second code          0.0690279006958   0.0308079719543
    
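    To reproduce a comparison like this on your own machine, a timeit harness along these lines could be used. It is a sketch: the random test data and the function names are assumptions, and the 500-iteration count is taken from the answer. Absolute numbers will differ by machine and NumPy build. The original question code and Sven's answer are not reproduced here.

```python
import timeit

import numpy as np


def rot_tensordot(T, g):
    # First version from the answer: rebuilds gggg on every call.
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)
    return np.tensordot(gggg, T, ((0, 2, 4, 6), (0, 1, 2, 3)))


def rot_cached(T, gggg):
    # Second version from the answer: gggg is precomputed by the caller.
    return np.dot(gggg.transpose((1, 3, 5, 7, 0, 2, 4, 6)).reshape((81, 81)),
                  T.reshape(81, 1)).reshape((3, 3, 3, 3))


def bench(number=500, seed=0):
    # Returns total seconds for `number` calls of each variant.
    rng = np.random.default_rng(seed)
    T = rng.standard_normal((3, 3, 3, 3))
    g, _ = np.linalg.qr(rng.standard_normal((3, 3)))
    gg = np.outer(g, g)
    gggg = np.outer(gg, gg).reshape(4 * g.shape)  # built outside the timed loop
    return {
        "tensordot": timeit.timeit(lambda: rot_tensordot(T, g), number=number),
        "cached gggg": timeit.timeit(lambda: rot_cached(T, gggg), number=number),
    }


number = 500
for name, seconds in bench(number=number).items():
    print(f"{name:12s} {seconds:.4f} s for {number} calls")
```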
