Fast tensor rotation with NumPy

情话喂你 2020-12-04 17:35

At the heart of an application (written in Python and using NumPy), I need to rotate a fourth-order tensor. Actually, I need to rotate a lot of tensors many times and this is my

7 answers
  •  粉色の甜心
    2020-12-04 18:21

    Proposed approach and solution code

    For memory efficiency, and in turn better performance, we can perform the tensor matrix-multiplications in steps.

    To illustrate the steps involved, let's start from @pv.'s np.einsum solution, the simplest of the posted solutions:

    np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T)
    

    Each of the four copies of g has its first axis (a, b, c, d) summed against the matching axis of T, so only the second axes of g (i, j, k, l) survive in the output.
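    To make the summation explicit, here is a minimal brute-force sketch of what that einsum computes: T'[i,j,k,l] is the sum over a, b, c, d of g[a,i]*g[b,j]*g[c,k]*g[d,l]*T[a,b,c,d]. The helper name rotT_loops, the random seed and the 3 x 3 sizes are illustrative assumptions; the loops are only meant as a correctness reference, not as a fast implementation.

```python
import itertools

import numpy as np


def rotT_loops(T, g):
    """Hypothetical reference: T'[i,j,k,l] = sum over a,b,c,d of
    g[a,i] * g[b,j] * g[c,k] * g[d,l] * T[a,b,c,d], as explicit loops."""
    n = g.shape[0]
    out = np.zeros((n, n, n, n))
    for i, j, k, l in itertools.product(range(n), repeat=4):
        for a, b, c, d in itertools.product(range(n), repeat=4):
            out[i, j, k, l] += g[a, i] * g[b, j] * g[c, k] * g[d, l] * T[a, b, c, d]
    return out


rng = np.random.default_rng(0)
T = rng.random((3, 3, 3, 3))
g = rng.random((3, 3))

# The einsum subscripts sum over the labels a, b, c, d that are absent from the output.
assert np.allclose(rotT_loops(T, g), np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T))
```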

    Let's perform these sum-reductions one tensor matrix-multiplication at a time, starting with the first copy of g and T:

    p1 = np.einsum('abcd, ai->bcdi', T, g)
    

    This leaves a tensor whose axes, in string notation, are bcdi. The next steps sum-reduce this tensor against the remaining three copies of g used in the original einsum implementation. Hence, the next reduction is:

    p2 = np.einsum('bcdi, bj->cdij', p1, g)
    

    Now the axes labeled a and b have both been summed away. Two more steps remove c and d as well, leaving ijkl as the final output:

    p3 = np.einsum('cdij, ck->dijk', p2, g)
    
    p4 = np.einsum('dijk, dl->ijkl', p3, g)
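    The four steps can be checked against the single einsum. This sketch assumes a rectangular 3 x 2 g purely for illustration: with unequal sizes, each intermediate's shape shows which axis was just summed away and where the new axis of g lands. A rotation matrix would of course be square; the seed is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative assumption: a 3 x 2 g, so summed (size 3) and new (size 2) axes differ.
T = rng.random((3, 3, 3, 3))
g = rng.random((3, 2))

p1 = np.einsum('abcd,ai->bcdi', T, g)   # a summed away -> shape (3, 3, 3, 2)
p2 = np.einsum('bcdi,bj->cdij', p1, g)  # b summed away -> shape (3, 3, 2, 2)
p3 = np.einsum('cdij,ck->dijk', p2, g)  # c summed away -> shape (3, 2, 2, 2)
p4 = np.einsum('dijk,dl->ijkl', p3, g)  # d summed away -> shape (2, 2, 2, 2)

# The four-step chain matches the single five-operand einsum.
assert np.allclose(p4, np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T))
```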
    

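    Now, we can use np.tensordot for these sum-reductions. Each step contracts axis 0 of the running result with axis 0 of g and appends g's remaining axis at the end, which is exactly the axis order the einsum steps above produce. np.tensordot reshapes its operands and hands the contraction to np.dot, which uses optimized BLAS routines for floating-point arrays, so it is typically much faster than einsum here.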
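    A quick check of the axis bookkeeping behind the switch to np.tensordot: axes=(0, 0) (the same as the axes=((0),(0)) spelling used in this answer, since (0) is just 0) sums axis 0 of T against axis 0 of g, and the result lists T's remaining axes followed by g's remaining axis, i.e. the bcdi layout of the first einsum step. The rectangular g and the seed are illustrative assumptions that make the trailing axis easy to spot.

```python
import numpy as np

rng = np.random.default_rng(0)
T = rng.random((3, 3, 3, 3))
g = rng.random((3, 2))  # rectangular only to make the new trailing axis visible

# Axes of the result: T's b, c, d (size 3 each) followed by g's i (size 2).
step = np.tensordot(T, g, axes=(0, 0))
assert step.shape == (3, 3, 3, 2)
assert np.allclose(step, np.einsum('abcd,ai->bcdi', T, g))
```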

    Final implementation

    Porting these steps over to np.tensordot gives the final implementation:

    p1 = np.tensordot(T,g,axes=((0),(0)))
    p2 = np.tensordot(p1,g,axes=((0),(0)))
    p3 = np.tensordot(p2,g,axes=((0),(0)))
    out = np.tensordot(p3,g,axes=((0),(0)))
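    Because every step is the same call, the chain can be written as a loop that works for a tensor of any order. The helper rotate_tensor below is a hypothetical generalization, not part of the original answer; it is checked against the four-index einsum and, for a second-order tensor M, against the familiar matrix form g.T @ M @ g.

```python
import numpy as np


def rotate_tensor(T, g):
    """Hypothetical helper: rotate a tensor of any order by contracting
    each of its axes with g in turn (the tensordot chain, as a loop)."""
    out = T
    for _ in range(T.ndim):
        # Contract the current leading axis with g's first axis; g's second
        # axis is appended at the end, so after T.ndim steps the order is restored.
        out = np.tensordot(out, g, axes=(0, 0))
    return out


rng = np.random.default_rng(0)
g = rng.random((3, 3))
T4 = rng.random((3, 3, 3, 3))
M = rng.random((3, 3))

assert np.allclose(rotate_tensor(T4, g), np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T4))
# For a second-order tensor the same rotation reduces to g.T @ M @ g.
assert np.allclose(rotate_tensor(M, g), g.T @ M @ g)
```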
    

    Runtime test

    Let's benchmark all of the NumPy-based approaches posted in the other answers to this question.

    The approaches, wrapped as functions:

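    import numpy as np
    
    def rotT_Philipp(T, g):  # @Philipp's soln
        gg = np.outer(g, g)
        # gggg[a, i, b, j, c, k, d, l] = g[a, i] * g[b, j] * g[c, k] * g[d, l]
        gggg = np.outer(gg, gg).reshape(4 * g.shape)
        axes = ((0, 2, 4, 6), (0, 1, 2, 3))  # contract a, b, c, d with T's axes
        return np.tensordot(gggg, T, axes)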
    
    def rotT_Sven(T, g):    # @Sven Marnach's soln
        Tprime = T
        for i in range(4):
            slices = [None] * 4
            slices[i] = slice(None)
            slices *= 2
            # index with a tuple: list-based multidimensional indexing no longer works in recent NumPy
            Tprime = g[tuple(slices)].T * Tprime
        return Tprime.sum(-1).sum(-1).sum(-1).sum(-1)    
    
    def rotT_pv(T, g):     # @pv.'s soln
        return np.einsum('ai,bj,ck,dl,abcd->ijkl', g, g, g, g, T)
    
    def rotT_Divakar(T,g): # Posted in this post
        p1 = np.tensordot(T,g,axes=((0),(0)))
        p2 = np.tensordot(p1,g,axes=((0),(0)))
        p3 = np.tensordot(p2,g,axes=((0),(0)))
        p4 = np.tensordot(p3,g,axes=((0),(0)))
        return p4
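    To see why rotT_Philipp is memory-hungry, this sketch checks what its outer-product construction holds: gggg[a, i, b, j, c, k, d, l] equals g[a,i]*g[b,j]*g[c,k]*g[d,l], so axes 0, 2, 4, 6 carry a, b, c, d, which is why those are the axes contracted against T. The seed and the 3 x 3 size are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.random((3, 3))

# Same construction as rotT_Philipp.
gg = np.outer(g, g)                           # shape (9, 9): rows (a, i), columns (b, j)
gggg = np.outer(gg, gg).reshape(4 * g.shape)  # shape (3,) * 8

assert gggg.shape == (3,) * 8
assert np.allclose(gggg, np.einsum('ai,bj,ck,dl->aibjckdl', g, g, g, g))
print(gggg.size)  # 6561, i.e. 3**8 elements held in memory at once
```

    The full product of four copies of g has n**8 entries, whereas each tensordot step in rotT_Divakar only ever holds an n**4 array.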
    

    Timings with the original dataset sizes (rotT is the original implementation from the question):

    In [304]: # Setup inputs 
         ...: T = np.random.rand(3,3,3,3)
         ...: g = np.random.rand(3,3)
         ...: 
    
    In [305]: %timeit rotT(T, g)
         ...: %timeit rotT_pv(T, g)
         ...: %timeit rotT_Sven(T, g)
         ...: %timeit rotT_Philipp(T, g)
         ...: %timeit rotT_Divakar(T, g)
         ...: 
    100 loops, best of 3: 6.51 ms per loop
    1000 loops, best of 3: 247 µs per loop
    10000 loops, best of 3: 137 µs per loop
    10000 loops, best of 3: 41.6 µs per loop
    10000 loops, best of 3: 28.3 µs per loop
    
    In [306]: 6510.0/28.3 # Speedup with the proposed soln over original code
    Out[306]: 230.03533568904592
    

    As discussed at the start of this post, the goal is memory efficiency and, through it, better performance. Let's test that as the dataset sizes increase:

    In [307]: # Setup inputs 
         ...: T = np.random.rand(5,5,5,5)
         ...: g = np.random.rand(5,5)
         ...: 
    
    In [308]: %timeit rotT(T, g)
         ...: %timeit rotT_pv(T, g)
         ...: %timeit rotT_Sven(T, g)
         ...: %timeit rotT_Philipp(T, g)
         ...: %timeit rotT_Divakar(T, g)
         ...: 
    100 loops, best of 3: 6.54 ms per loop
    100 loops, best of 3: 7.17 ms per loop
    100 loops, best of 3: 2.7 ms per loop
    1000 loops, best of 3: 1.47 ms per loop
    10000 loops, best of 3: 39.9 µs per loop
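    The growing gap is consistent with a rough element-count model of the largest array each approach materialises for an n x n g. This is an assumption-based sketch, not a measurement: it ignores any temporary copies tensordot makes internally, and largest_intermediate is a hypothetical helper. rotT_pv is left out because, with np.einsum's default optimize=False, it builds no large intermediate, although its single pass presumably still visits all n**8 index combinations.

```python
def largest_intermediate(n):
    """Hypothetical model: element count of the largest array each
    approach builds for an n x n rotation matrix g."""
    return {
        "rotT_Philipp": n ** 8,  # gggg has shape (n,) * 8
        "rotT_Sven": n ** 8,     # the broadcast product grows to shape (n,) * 8
        "rotT_Divakar": n ** 4,  # every tensordot step yields an (n,) * 4 array
    }


for n in (3, 5):
    print(n, largest_intermediate(n))
# 3 {'rotT_Philipp': 6561, 'rotT_Sven': 6561, 'rotT_Divakar': 81}
# 5 {'rotT_Philipp': 390625, 'rotT_Sven': 390625, 'rotT_Divakar': 625}
```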
    
