For performance reasons, I'm curious if there is a way to multiply a stack of a stack of matrices. I have a 4-D array (500, 201, 2, 2). Its basi
I don't think it's possible to do this efficiently using numpy (the cumprod solution was elegant, though). This is the sort of situation where I would use f2py. It's the simplest way I know of to call a faster language, and it only requires a single extra file.
fortran.f90:
subroutine multimul(a, b)
    implicit none
    ! a holds the transposed input: (2, 2, nstack, nmat)
    real(8), intent(in)  :: a(:,:,:,:)
    real(8), intent(out) :: b(size(a,1),size(a,2),size(a,3))
    real(8) :: work(size(a,1),size(a,2))
    integer :: i, j
    !$omp parallel do private(work,i,j)
    do i = 1, size(b,3)
        ! start from the last matrix and multiply backwards,
        ! because every 2x2 matrix arrives transposed
        b(:,:,i) = a(:,:,i,size(a,4))
        do j = size(a,4)-1, 1, -1
            work = matmul(b(:,:,i),a(:,:,i,j))
            b(:,:,i) = work
        end do
    end do
end subroutine
Compile with f2py -c -m fmuls fortran.f90
(or F90FLAGS="-fopenmp" f2py -c -m fmuls fortran.f90 -lgomp
to enable OpenMP acceleration). The module name given to -m must match the name you import. Then you would use it in your script as
import numpy as np, fmuls
from functools import reduce  # a builtin in Python 2, needs this import in Python 3

Arr = np.random.standard_normal([500,201,2,2])

def loopMult(Arr):
    ArrMult = Arr[0]
    for i in range(1,len(Arr)):
        ArrMult = np.einsum('fij,fjk->fik', ArrMult, Arr[i])
    return ArrMult

def myeinsum(A1, A2):
    return np.einsum('fij,fjk->fik', A1, A2)

A1 = loopMult(Arr)
A2 = reduce(myeinsum, Arr)
A3 = fmuls.multimul(Arr.T).T
print(np.allclose(A1,A2))
print(np.allclose(A1,A3))
# the %timeit lines below are IPython magics
%timeit loopMult(Arr)
%timeit reduce(myeinsum, Arr)
%timeit fmuls.multimul(Arr.T).T
Which outputs
True
True
10 loops, best of 3: 48.4 ms per loop
10 loops, best of 3: 48.8 ms per loop
100 loops, best of 3: 5.82 ms per loop
So that's a factor of 8 speedup. The reason for all the transposes is that f2py implicitly transposes all the arrays, and we need to transpose them manually to tell it that our Fortran code expects things to be transposed. This avoids a copy operation. The cost is that each of our 2x2 matrices is transposed, and since (AB)^T = B^T A^T, we have to loop in reverse to avoid performing the wrong operation.
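If you want to convince yourself of the transpose argument without compiling anything, here is a small numpy-only sketch. It is not the Fortran routine itself, just an emulation of its loop order on a toy array; the names (arr, forward, backward) and the small shape are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small stand-in for the (500, 201, 2, 2) array: 5 matrices per stack, 3 stacks.
arr = rng.standard_normal((5, 3, 2, 2))

# .T reverses all axes and returns a view, so no data is copied and the
# C-ordered array is seen as a Fortran-ordered (2, 2, 3, 5) array.
arr_t = arr.T
is_view = np.shares_memory(arr, arr_t)
is_fortran_ordered = arr_t.flags.f_contiguous

# Reference: forward product A_0 A_1 ... A_{n-1} for every stack.
forward = arr[0]
for a in arr[1:]:
    forward = forward @ a

# Emulate what the Fortran loop sees: every 2x2 matrix transposed,
# multiplied starting from the last one and going backwards.
arr_mt = np.swapaxes(arr, -1, -2)
backward = arr_mt[-1]
for j in range(len(arr) - 2, -1, -1):
    backward = backward @ arr_mt[j]

# (A_0 ... A_{n-1})^T = A_{n-1}^T ... A_0^T, so transposing back recovers it.
recovered = np.swapaxes(backward, -1, -2)
same_result = np.allclose(forward, recovered)
```

Here same_result comes out True because the reverse-order product of the transposed matrices is exactly the transpose of the forward product, which is what the final .T in the script undoes.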
Speedups greater than a factor of 8 should be possible - I didn't spend any time trying to optimize this.