Python: Running nested loop, 2D moving window, in Parallel

前端 未结 2 497
灰色年华
灰色年华 2021-01-06 04:31

I work with topographic data. For one particular problem, I have written a function in Python which uses a moving window of a particular size to zip through a matrix (grid o

2条回答
  •  [愿得一人]
    2021-01-06 05:22

    A solution using Numba

    This is very easy to do when every function you use is supported by Numba. In your code, win = signal.detrend(win, type = 'linear') is the one part you have to re-implement yourself, because Numba does not support this function.

    Implementing detrend in Numba

    If you look at the source-code of detrend, and extract the relevant parts for your problem, it may look like this:

    @nb.njit()
    def detrend(w):
        """Remove a least-squares straight line from every row of ``w``.

        Numba-compilable stand-in, modelled on the linear branch of
        ``scipy.signal.detrend`` applied along the last axis of a 2D array.

        Parameters
        ----------
        w : 2D array
            Window to detrend.

        Returns
        -------
        2D array
            Residuals ``w - fitted_line`` with the same shape as ``w``.
        """
        # The fit runs along the last axis, so the design matrix needs one row
        # per *column* of w. (w.shape[0] only worked for square windows.)
        Npts = w.shape[1]
        A = np.empty((Npts, 2), dtype=w.dtype)
        for i in range(Npts):
            A[i, 0] = 1. * (i + 1) / Npts  # normalised sample position (slope term)
            A[i, 1] = 1.                   # constant (intercept) term

        # Each column of w.T is one row of w; all rows are fitted in one call.
        coef, resids, rank, s = np.linalg.lstsq(A, w.T)
        out = w.T - np.dot(A, coef)
        return out.T
    

    I also implemented a faster solution for np.max(np.isnan(win)) == 1

    @nb.njit()
    def isnan(win):
        """Return True as soon as a NaN is found in the 2D array ``win``.

        Scans element by element and stops at the first NaN; returns False
        when the whole array has been visited without finding one.
        """
        n_rows, n_cols = win.shape[0], win.shape[1]
        for r in range(n_rows):
            for c in range(n_cols):
                value = win[r, c]
                if np.isnan(value):
                    return True
        return False
    

    Main function

    Since Numba is used here, the parallelization is very simple: just a prange on the outer loop and parallel=True in the njit decorator.

    import numpy as np
    import numba as nb
    
    @nb.njit(parallel=True)
    def RMSH_det_nb(DEM, w):
        """Moving-window RMS height of a linearly detrended grid.

        For every cell ``(i, j)`` with ``w+1 <= i < nrows-w`` and
        ``w+1 <= j < ncols-w`` the window ``DEM[i-w:i+w-1, j-w:j+w-1]`` is
        taken. If it contains a NaN the result is NaN; otherwise the window is
        detrended and ``sqrt(1/(n-1) * sum((z - mean(z))**2))`` over its ``n``
        values is stored. Cells outside that index range stay NaN.

        Parameters
        ----------
        DEM : 2D array
            Input grid.
        w : int
            Window half-size used in the slicing above.

        Returns
        -------
        2D array with the shape of ``DEM``.
        """
        nrows, ncols = DEM.shape

        # Result grid, pre-filled with NaN (anything times NaN is NaN).
        rms = DEM*np.nan

        # Rows are independent, so the outer loop is distributed with prange.
        for row in nb.prange(w+1, nrows-w):
            for col in range(w+1, ncols-w):
                window = DEM[row-w:row+w-1, col-w:col+w-1]

                if isnan(window):
                    rms[row, col] = np.nan
                else:
                    flat = detrend(window).flatten()
                    count = flat.size
                    rms[row, col] = np.sqrt(1 / (count - 1) * np.sum((flat-np.mean(flat))**2))

        return rms
    

    Timings (small example)

    # Small benchmark: 100x100 float32 grid of uniform random values, w = 10.
    w = 10
    DEM=np.random.rand(100, 100).astype(np.float32)
    
    # Check that the Numba version reproduces the reference result.
    # NOTE(review): RMSH_det is not defined in this answer — presumably the
    # original function from the question; confirm.
    res1=RMSH_det(DEM, w)
    res2=RMSH_det_nb(DEM, w)
    print(np.allclose(res1,res2,equal_nan=True))
    #True
    
    # IPython %timeit magics; the comment under each call is the recorded output.
    %timeit res1=RMSH_det(DEM, w)
    #1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit res2=RMSH_det_nb(DEM, w) #approx. 55 times faster
    #29 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    Timings for larger arrays

    # Same benchmark on a 1355x1165 float32 grid; only the Numba version is timed.
    w = 10
    DEM=np.random.rand(1355, 1165).astype(np.float32)
    %timeit res2=RMSH_det_nb(DEM, w)
    #6.63 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    [Edit] Implementation using normal equations

    Overdetermined system

    This method has lower numerical precision than the least-squares solver, but it is considerably faster.

    @nb.njit()
    def isnan(win):
        """Return True if the 2D array ``win`` contains at least one NaN.

        Stops at the first NaN found; returns False when none is present.
        """
        for i in range(win.shape[0]):
            for j in range(win.shape[1]):
                # NaN never compares equal to anything (not even itself), so
                # ``win[i, j] == np.nan`` is always False; np.isnan is required.
                if np.isnan(win[i, j]):
                    return True
        return False
    
    @nb.njit()
    def detrend(w):
        """Remove a least-squares straight line from every row of ``w``.

        Numba-compilable stand-in, modelled on the linear branch of
        ``scipy.signal.detrend`` applied along the last axis of a 2D array.

        Parameters
        ----------
        w : 2D array
            Window to detrend.

        Returns
        -------
        2D array
            Residuals ``w - fitted_line`` with the same shape as ``w``.
        """
        # The fit runs along the last axis, so the design matrix needs one row
        # per *column* of w. (w.shape[0] only worked for square windows.)
        Npts = w.shape[1]
        A = np.empty((Npts, 2), dtype=w.dtype)
        for i in range(Npts):
            A[i, 0] = 1. * (i + 1) / Npts  # normalised sample position (slope term)
            A[i, 1] = 1.                   # constant (intercept) term

        # Each column of w.T is one row of w; all rows are fitted in one call.
        coef, resids, rank, s = np.linalg.lstsq(A, w.T)
        out = w.T - np.dot(A, coef)
        return out.T
    
    @nb.njit()
    def detrend_2(w,T1,A):
        """Detrend ``w`` by solving the normal equations ``T1 @ coef = A.T @ w.T``.

        Parameters
        ----------
        w : 2D array
            Window to detrend.
        T1 : 2D array
            Left-hand side of the normal equations, supplied by the caller.
        A : 2D array
            Design matrix used for the fit.

        Returns
        -------
        2D array
            ``(w.T - A @ coef).T``, i.e. the residuals in the layout of ``w``.
        """
        rhs = np.dot(A.T, w.T)
        coef = np.linalg.solve(T1, rhs)

        residual_t = w.T - np.dot(A, coef)

        return residual_t.T
    
    @nb.njit(parallel=True)
    def RMSH_det_nb_normal_eq(DEM,w):
        """Moving-window RMS height, detrending via the normal equations.

        Uses the same window slicing ``DEM[i-w:i+w-1, j-w:j+w-1]`` and loop
        bounds as the lstsq-based variant, but builds the design matrix ``A``
        and ``T1 = A.T @ A`` once, outside the loops, and passes them to
        ``detrend_2`` for every window.

        Parameters
        ----------
        DEM : 2D array
            Input grid.
        w : int
            Window half-size; the window side length is ``2*w - 1``.

        Returns
        -------
        2D array with the shape of ``DEM``; cells that are not processed, or
        whose window contains NaN, hold NaN.
        """
        nrows, ncols = DEM.shape

        # Result grid, pre-filled with NaN (anything times NaN is NaN).
        rms = DEM*np.nan

        # Design matrix for a straight-line fit over one window side.
        Npts = w*2-1
        A = np.empty((Npts, 2), dtype=DEM.dtype)
        for k in range(Npts):
            A[k, 0] = 1.*(k+1) / Npts
            A[k, 1] = 1.

        # Loop-invariant left-hand side of the normal equations.
        T1 = np.dot(A.T, A)

        # Number of samples per window.
        nz = Npts**2
        for row in nb.prange(w+1, nrows-w):
            for col in range(w+1, ncols-w):
                window = DEM[row-w:row+w-1, col-w:col+w-1]

                if isnan(window):
                    rms[row, col] = np.nan
                else:
                    resid = detrend_2(window, T1, A)
                    rms[row, col] = np.sqrt(1 / (nz - 1) * np.sum((resid-np.mean(resid))**2))

        return rms
    

    Timings

    # Small benchmark: 100x100 float32 grid of uniform random values, w = 10.
    w = 10
    DEM=np.random.rand(100, 100).astype(np.float32)
    
    res1=RMSH_det(DEM, w)
    # NOTE(review): the correctness check below runs RMSH_det_nb, while the
    # timing further down measures RMSH_det_nb_normal_eq — confirm whether the
    # normal-equations variant was meant to be verified here.
    res2=RMSH_det_nb(DEM, w)
    print(np.allclose(res1,res2,equal_nan=True))
    #True
    
    # IPython %timeit magics; the comment under each call is the recorded output.
    %timeit res1=RMSH_det(DEM, w)
    #1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit res2=RMSH_det_nb_normal_eq(DEM,w)
    #7.97 ms ± 89.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Optimized solution using normal equations

    Temporary arrays are reused to avoid costly memory allocations, and a custom matrix-multiplication routine is used. This is only recommended for very small matrices; in most other cases np.dot (sgemm) will be a lot faster.

    @nb.njit()
    def matmult_2(A,B,out):
        """Write the product ``A @ B`` into ``out`` for a two-row ``A``.

        Only rows 0 and 1 of ``A`` are read, and only rows 0 and 1 of ``out``
        are written; ``out`` is filled in place and also returned.
        """
        n_inner = B.shape[0]
        n_cols = B.shape[1]
        for col in range(n_cols):
            # One running sum per output row, both fed from the same B column.
            top = nb.float32(0)
            bottom = nb.float32(0)
            for k in range(n_inner):
                top += A[0, k]*B[k, col]
                bottom += A[1, k]*B[k, col]
            out[0, col] = top
            out[1, col] = bottom
        return out
    
    @nb.njit(fastmath=True)
    def matmult_mod(A,B,w,out):
        """Fused two-column product and subtraction, written into ``out``.

        For every ``i, j`` stores ``A[i,0]*B[0,j] + A[i,1]*B[1,j] - w[j,i]``
        in ``out[j,i]``. Only columns 0 and 1 of ``A`` and rows 0 and 1 of
        ``B`` are read. ``out`` is filled in place and also returned.
        """
        n_rows_a = A.shape[0]
        for col in range(B.shape[1]):
            for row in range(n_rows_a):
                fitted = nb.float32(0)
                fitted += A[row, 0]*B[0, col]+A[row, 1]*B[1, col]
                # Stored transposed, as fitted value minus data.
                out[col, row] = fitted-w[col, row]
        return out
    
    @nb.njit()
    def detrend_2_opt(w,T1,A,Tempvar_1,Tempvar_2):
        """Normal-equations detrend that writes into caller-supplied buffers.

        ``Tempvar_1`` receives ``A.T @ w.T`` (via ``matmult_2``) and
        ``Tempvar_2`` receives the final result (via ``matmult_mod``), which is
        also the return value.
        """
        # Right-hand side of the normal equations, stored in the first buffer.
        rhs = matmult_2(A.T, w.T, Tempvar_1)
        line_coef = np.linalg.solve(T1, rhs)
        # Fitted line minus data, stored in the second buffer.
        return matmult_mod(A, line_coef, w, Tempvar_2)
    
    @nb.njit(parallel=True)
    def RMSH_det_nb_normal_eq_opt(DEM,w):
        """Moving-window RMS height using the buffer-reusing detrend.

        Same window slicing ``DEM[i-w:i+w-1, j-w:j+w-1]`` and loop bounds as
        the other variants. The design matrix ``A`` and ``T1 = A.T @ A`` are
        built once; two scratch arrays are allocated per outer-loop iteration
        and handed to ``detrend_2_opt`` for every window in that row.

        Parameters
        ----------
        DEM : 2D array
            Input grid.
        w : int
            Window half-size; the window side length is ``2*w - 1``.

        Returns
        -------
        2D array with the shape of ``DEM``; cells that are not processed, or
        whose window contains NaN, hold NaN.
        """
        nrows, ncols = DEM.shape

        # Result grid, pre-filled with NaN (anything times NaN is NaN).
        rms = DEM*np.nan

        # Design matrix for a straight-line fit over one window side.
        Npts = w*2-1
        A = np.empty((Npts, 2), dtype=DEM.dtype)
        for k in range(Npts):
            A[k, 0] = 1.*(k+1) / Npts
            A[k, 1] = 1.

        # Loop-invariant left-hand side of the normal equations.
        T1 = np.dot(A.T, A)

        # Number of samples per window.
        nz = Npts**2
        for row in nb.prange(w+1, nrows-w):
            # Scratch buffers are created inside the prange body, once per row,
            # and reused for every column of that row.
            scratch_rhs = np.empty((2, Npts), dtype=DEM.dtype)
            scratch_resid = np.empty((Npts, Npts), dtype=DEM.dtype)
            for col in range(w+1, ncols-w):
                window = DEM[row-w:row+w-1, col-w:col+w-1]

                if isnan(window):
                    rms[row, col] = np.nan
                else:
                    resid = detrend_2_opt(window, T1, A, scratch_rhs, scratch_resid)
                    rms[row, col] = np.sqrt(1 / (nz - 1) * np.sum((resid-np.mean(resid))**2))

        return rms
    

    Timings

    # Small benchmark: 100x100 float32 grid of uniform random values, w = 10.
    w = 10
    DEM=np.random.rand(100, 100).astype(np.float32)
    
    # Check the optimized normal-equations version against the reference result.
    # NOTE(review): RMSH_det is not defined in this answer — presumably the
    # original function from the question; confirm.
    res1=RMSH_det(DEM, w)
    res2=RMSH_det_nb_normal_eq_opt(DEM, w)
    print(np.allclose(res1,res2,equal_nan=True))
    #True
    
    # IPython %timeit magics; the comment under each call is the recorded output.
    %timeit res1=RMSH_det(DEM, w)
    #1.59 s ± 72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit res2=RMSH_det_nb_normal_eq_opt(DEM,w)
    #4.66 ms ± 87.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    Timings for isnan

    This function is a completely different implementation. It is much faster when a NaN occurs near the beginning of the array, but even when there is none it still gives some speedup. I benchmarked it with small arrays (approximately the window size) and with a large array size suggested by @user3666197.

    # Small 20x20 inputs (roughly the window size):
    #   case_1 - every element is NaN
    #   case_2 - zeros with a single NaN at [10, 10]
    #   case_3 - zeros, no NaN
    case_1=np.full((20,20),np.nan)
    case_2=np.full((20,20),0.)
    case_2[10,10]=np.nan
    case_3=np.full((20,20),0.)
    
    # Large 10000x10000 inputs: case_4 is all NaN, case_5 is all ones (no NaN).
    case_4 = np.full( ( int( 1E4 ), int( 1E4 ) ),np.nan)
    case_5 = np.ones( ( int( 1E4 ), int( 1E4 ) ) )
    
    # Baseline: NumPy builds the full boolean mask and then reduces it.
    # The five result lines below are in the same order as the five calls.
    %timeit np.any(np.isnan(case_1))
    %timeit np.any(np.isnan(case_2))
    %timeit np.any(np.isnan(case_3))
    %timeit np.any(np.isnan(case_4))
    %timeit np.any(np.isnan(case_5))
    #2.75 µs ± 73.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #2.75 µs ± 46.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #2.76 µs ± 32.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    #81.3 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    #86.7 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    # Numba isnan helper from this answer, which returns at the first NaN found.
    # The five result lines below are in the same order as the five calls.
    %timeit isnan(case_1)
    %timeit isnan(case_2)
    %timeit isnan(case_3)
    %timeit isnan(case_4)
    %timeit isnan(case_5)
    #244 ns ± 5.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    #357 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    #475 ns ± 9.28 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    #235 ns ± 0.933 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    #58.8 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

提交回复
热议问题