Fast interpolation over 3D array

梦毁少年i 2020-12-13 22:03

I have a 3D array that I need to interpolate over one axis (the last dimension). Let's say y.shape = (nx, ny, nz), I want to interpolate in nz for

6 answers
  •  粉色の甜心
    2020-12-13 22:57

    Building on @pv.'s answer, and vectorising the inner loop, the following gives a substantial speedup (EDIT: changed the expensive numpy.tile to using numpy.lib.stride_tricks.as_strided):

    import numpy
    from scipy import interpolate
    
    nx = 30
    ny = 40
    nz = 50
    
    y = numpy.random.randn(nx, ny, nz)
    x = numpy.float64(numpy.arange(0, nz))
    
    # We select some locations in the range [0.1, nz-1], in steps of 0.1
    # (randint excludes its upper bound, hence the +1)
    new_z = numpy.random.randint(1, (nz-1)*10 + 1, size=(nx, ny))/10.0
    
    # y is a 3D ndarray
    # x is a 1D ndarray with the abscissa values
    # new_z is a 2D array
    
    def original_interpolation():
        result = numpy.empty(y.shape[:-1])
        for i in range(nx):
            for j in range(ny):
                f = interpolate.interp1d(x, y[i, j], axis=-1, kind='linear')
                result[i, j] = f(new_z[i, j])
    
        return result
    
    grid_x, grid_y = numpy.mgrid[0:nx, 0:ny]
    def faster_interpolation():
        flat_new_z = new_z.ravel()
        k = numpy.searchsorted(x, flat_new_z)
        k = k.reshape(nx, ny)
    
        # Index with tuples: recent NumPy treats a list of arrays as one array index
        lower_index = (grid_x, grid_y, k-1)
        upper_index = (grid_x, grid_y, k)
    
        tiled_x = numpy.lib.stride_tricks.as_strided(x, shape=(nx, ny, nz), 
            strides=(0, 0, x.itemsize))
    
        z_upper = tiled_x[upper_index]
        z_lower = tiled_x[lower_index]
    
        z_step = z_upper - z_lower
        z_delta = new_z - z_lower
    
        y_lower = y[lower_index]
        result = y_lower + z_delta * (y[upper_index] - y_lower)/z_step
    
        return result
    
    # both should be the same (giving a small difference)
    print(numpy.max(
            numpy.abs(original_interpolation() - faster_interpolation())))
    

    That gives the following times on my machine:

    In [8]: timeit foo.original_interpolation()
    10 loops, best of 3: 102 ms per loop
    
    In [9]: timeit foo.faster_interpolation()
    1000 loops, best of 3: 564 us per loop
    

    Going to nx = 300, ny = 300 and nz = 500 gives a speedup of more than 130x:

    In [2]: timeit original_interpolation()
    1 loops, best of 3: 8.27 s per loop
    
    In [3]: timeit faster_interpolation()
    10 loops, best of 3: 60.1 ms per loop
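
    On recent NumPy versions the same linear scheme can be written without the index grids or the strided view. The following is only a sketch, assuming numpy.take_along_axis is available (NumPy 1.15 or later). The name interp_last_axis_linear is made up for this example, and the clipping of k is an addition that keeps k-1 valid when new_z equals x[0]:

```python
import numpy as np


def interp_last_axis_linear(x, y, new_z):
    """Linearly interpolate y (nx, ny, nz) along its last axis at new_z (nx, ny).

    Assumes x is 1D, strictly increasing, and x[0] <= new_z <= x[-1].
    (Hypothetical helper, not part of NumPy or SciPy.)
    """
    # Index of the upper bracketing sample; clipping keeps k-1 >= 0.
    k = np.clip(np.searchsorted(x, new_z), 1, len(x) - 1)

    # x is 1D, so fancy indexing with the 2D k yields (nx, ny) arrays directly.
    x_lower, x_upper = x[k - 1], x[k]

    # Gather y[i, j, k[i, j]] for every (i, j) without building index grids.
    y_lower = np.take_along_axis(y, (k - 1)[..., None], axis=-1)[..., 0]
    y_upper = np.take_along_axis(y, k[..., None], axis=-1)[..., 0]

    weight = (new_z - x_lower) / (x_upper - x_lower)
    return y_lower + weight * (y_upper - y_lower)


rng = np.random.default_rng(0)
nx, ny, nz = 30, 40, 50
y = rng.standard_normal((nx, ny, nz))
x = np.arange(nz, dtype=np.float64)
new_z = rng.uniform(0.0, nz - 1, size=(nx, ny))

result = interp_last_axis_linear(x, y, new_z)

# Cross-check against numpy.interp, one (i, j) column at a time.
reference = np.array([[np.interp(new_z[i, j], x, y[i, j]) for j in range(ny)]
                      for i in range(nx)])
print(np.max(np.abs(result - reference)))
```

    Since x is 1D, x[k] already returns an (nx, ny) array, so no tiled copy of x is needed.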
    

    You'd need to write your own algorithm for cubic interpolation, but it shouldn't be too hard.
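
    As a rough starting point for the cubic case, here is a minimal sketch of a local four-point Lagrange cubic, vectorised in the same way. This is a swapped-in technique and not a cubic spline, so it will not match interpolate.interp1d(..., kind='cubic') exactly. The name interp_last_axis_cubic_local is hypothetical, and the code assumes x is strictly increasing with at least four samples:

```python
import numpy as np


def interp_last_axis_cubic_local(x, y, new_z):
    """Local cubic (4-point Lagrange) interpolation of y along its last axis.

    x: (nz,) strictly increasing, nz >= 4; y: (nx, ny, nz); new_z: (nx, ny)
    with x[0] <= new_z <= x[-1]. Each result uses only the four samples
    around new_z, so derivatives are not continuous across sample points.
    (Hypothetical helper, not part of NumPy or SciPy.)
    """
    nz = len(x)
    k = np.searchsorted(x, new_z)              # upper bracketing index
    start = np.clip(k - 2, 0, nz - 4)          # first of the four samples
    idx = start[..., None] + np.arange(4)      # (nx, ny, 4) sample indices

    xs = x[idx]                                # (nx, ny, 4) sample positions
    ys = np.take_along_axis(y, idx, axis=-1)   # (nx, ny, 4) sample values

    result = np.zeros(new_z.shape)
    for m in range(4):
        # Lagrange basis polynomial for sample m, evaluated at new_z.
        weight = np.ones(new_z.shape)
        for n in range(4):
            if n != m:
                weight *= (new_z - xs[..., n]) / (xs[..., m] - xs[..., n])
        result += weight * ys[..., m]
    return result


# Sanity check: a four-point cubic must reproduce cubic polynomials.
rng = np.random.default_rng(0)
nx, ny, nz = 30, 40, 50
x = np.arange(nz, dtype=np.float64)
coeffs = rng.standard_normal((nx, ny, 4))      # one cubic per (i, j) column
y = coeffs @ (x[:, None] ** np.arange(4)).T    # (nx, ny, nz) samples
new_z = rng.uniform(0.0, nz - 1, size=(nx, ny))

result = interp_last_axis_cubic_local(x, y, new_z)
exact = np.sum(coeffs * new_z[..., None] ** np.arange(4), axis=-1)
print(np.max(np.abs(result - exact)))
```

    Near the ends of x the four-sample window is shifted inwards instead of extrapolating, which is one possible choice among several.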
