Python, numpy, einsum multiply a stack of matrices

遇见更好的自我 2021-01-02 03:17

For performance reasons, I'm curious if there is a way to multiply a stack of a stack of matrices. I have a 4-D array (500, 201, 2, 2). Its basi

1 answer
  • 2021-01-02 03:41

    I don't think it's possible to do this efficiently using numpy (the cumprod solution was elegant, though). This is the sort of situation where I would use f2py. It's the simplest way of calling a faster language that I know of and only requires a single extra file.

    fortran.f90:

    subroutine multimul(a, b)
      implicit none
      ! When called with Arr.T, a(:,:,i,j) is the transpose of the
      ! NumPy matrix Arr[j-1, i-1].
      real(8), intent(in)  :: a(:,:,:,:)
      real(8), intent(out) :: b(size(a,1),size(a,2),size(a,3))
      real(8) :: work(size(a,1),size(a,2))
      integer :: i, j
      !$omp parallel do private(work,i,j)
      do i = 1, size(b,3)
        ! Start from the last factor and multiply backwards.
        b(:,:,i) = a(:,:,i,size(a,4))
        do j = size(a,4)-1, 1, -1
          work = matmul(b(:,:,i),a(:,:,i,j))
          b(:,:,i) = work
        end do
      end do
    end subroutine
    

    Compile with f2py -c -m fmuls fortran.f90 (or F90FLAGS="-fopenmp" f2py -c -m fmuls fortran.f90 -lgomp to enable OpenMP acceleration). The module name given to -m must match the name imported in Python. Then use it in your script as follows:

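    import numpy as np
    from functools import reduce  # reduce is not a builtin in Python 3
    import fmuls                  # the f2py-compiled module
    Arr = np.random.standard_normal([500,201,2,2])
    def loopMult(Arr):
      # Explicit loop: multiply the 500 stacks of 201 matrices one by one
      ArrMult = Arr[0]
      for i in range(1,len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
      return ArrMult
    def myeinsum(A1, A2):
      return np.einsum('fij,fjk->fik', A1, A2)
    A1 = loopMult(Arr)
    A2 = reduce(myeinsum, Arr)
    A3 = fmuls.multimul(Arr.T).T
    print(np.allclose(A1,A2))
    print(np.allclose(A1,A3))
    # The %timeit lines are IPython magics; run them in IPython or Jupyter
    %timeit loopMult(Arr)
    %timeit reduce(myeinsum, Arr)
    %timeit fmuls.multimul(Arr.T).T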
    

    Which outputs

    True
    True
    10 loops, best of 3: 48.4 ms per loop
    10 loops, best of 3: 48.8 ms per loop
    100 loops, best of 3: 5.82 ms per loop
    

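    So that's a speedup of roughly a factor of 8 (48.4 ms versus 5.82 ms). The reason for all the transposes is that f2py implicitly transposes all the arrays (Fortran stores arrays in column-major order, while NumPy defaults to row-major), and we need to transpose them manually to tell it that our Fortran code expects things to be transposed. This avoids a copy operation. The cost is that each of our 2x2 matrices is transposed, so to avoid performing the wrong operation we have to loop in reverse: since (AB)^T = B^T A^T, multiplying the transposed matrices in reverse order gives the transpose of the product we want, and the final .T undoes that transpose.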
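The reverse-order argument can be checked without compiling anything. The following is a minimal sketch in pure NumPy that mimics what the Fortran loop does on the transposed array; it is only an illustration of the index bookkeeping, not a fast implementation. The small shape (5, 3, 2, 2), the seeded generator, and the names `forward`, `a`, `b` and `acc` are assumptions made for the example.

```python
import numpy as np
from functools import reduce

rng = np.random.default_rng(0)
# Small stand-in for the (500, 201, 2, 2) array: N = 5 factors, F = 3 stacks.
Arr = rng.standard_normal((5, 3, 2, 2))

# Reference result: Arr[0] @ Arr[1] @ ... @ Arr[-1], batched over the F axis.
forward = reduce(np.matmul, Arr)

# What the Fortran routine receives: a = Arr.T has shape (2, 2, F, N),
# and a[:, :, i, j] is the transpose of Arr[j, i].
a = Arr.T
b = np.empty(a.shape[:3])
for i in range(a.shape[2]):
    acc = a[:, :, i, -1]                      # start from the last factor
    for j in range(a.shape[3] - 2, -1, -1):   # then walk backwards
        acc = acc @ a[:, :, i, j]
    b[:, :, i] = acc

# b[:, :, i] holds the transposed product, so b.T recovers the forward result.
print(np.allclose(b.T, forward))
```

Here `b.T` plays the role of `fmuls.multimul(Arr.T).T` in the script above: the inner loop accumulates the transposed factors from last to first, and the outer transpose restores the (F, 2, 2) layout.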

    Greater speedups than 8 should be possible - I didn't spend any time trying to optimize this.
