Speeding up distance matrix computation with Numpy and Cython

萌比男神i asked on 2020-12-30 11:17

Consider a numpy array A of dimensionality NxM. The goal is to compute the Euclidean distance matrix D, where each element D[i,j] is the Euclidean distance between rows i and j. What is the fastest way to do this?

1 Answer

  •  既然无缘, 2020-12-30 12:03

    The key thing with Cython is to avoid using Python objects and function calls as much as possible, including vectorized operations on numpy arrays. This usually means writing out all of the loops by hand and operating on single array elements at a time.

    There's a very useful tutorial in the Cython documentation that covers the process of converting numpy code to Cython and optimizing it.

    Here's a quick stab at a more optimized Cython version of your distance function:

    import numpy as np
    cimport numpy as np
    cimport cython
    
    # don't use np.sqrt - the sqrt function from the C standard library is much
    # faster
    from libc.math cimport sqrt
    
    # disable checks that ensure that array indices don't go out of bounds. this is
    # faster, but you'll get a segfault if you mess up your indexing.
    @cython.boundscheck(False)
    # this disables 'wraparound' indexing from the end of the array using negative
    # indices.
    @cython.wraparound(False)
    def dist(double [:, :] A):
    
        # declare C types for as many of our variables as possible. note that we
        # don't necessarily need to assign a value to them at declaration time.
        cdef:
            # Py_ssize_t is just a special platform-specific type for indices
            Py_ssize_t nrow = A.shape[0]
            Py_ssize_t ncol = A.shape[1]
            Py_ssize_t ii, jj, kk
    
            # this line is particularly expensive, since creating a numpy array
            # involves unavoidable Python API overhead
            np.ndarray[np.float64_t, ndim=2] D = np.zeros((nrow, nrow), np.double)
    
            double tmpss, diff
    
        # another advantage of using Cython rather than broadcasting is that we can
        # exploit the symmetry of D by only looping over its upper triangle
        for ii in range(nrow):
            for jj in range(ii + 1, nrow):
                # we use tmpss to accumulate the SSD over each pair of rows
                tmpss = 0
                for kk in range(ncol):
                    diff = A[ii, kk] - A[jj, kk]
                    tmpss += diff * diff
                tmpss = sqrt(tmpss)
                D[ii, jj] = tmpss
                D[jj, ii] = tmpss  # because D is symmetric
    
        return D
    

    I saved this in a file called fastdist.pyx. We can use pyximport to simplify the build process:

    import pyximport
    pyximport.install()
    import fastdist
    import numpy as np
    
    A = np.random.randn(100, 200)
    
    D1 = np.sqrt(np.square(A[np.newaxis,:,:]-A[:,np.newaxis,:]).sum(2))
    D2 = fastdist.dist(A)
    
    print(np.allclose(D1, D2))
    # True
    

    So it works, at least. Let's do some benchmarking using the %timeit magic:

    %timeit np.sqrt(np.square(A[np.newaxis,:,:]-A[:,np.newaxis,:]).sum(2))
    # 100 loops, best of 3: 10.6 ms per loop
    
    %timeit fastdist.dist(A)
    # 100 loops, best of 3: 1.21 ms per loop
    

    A ~9x speed-up is nice, but not really a game-changer. As you said, though, the big problem with the broadcasting approach is the memory requirements of constructing the intermediate array.

    A2 = np.random.randn(1000, 2000)
    %timeit fastdist.dist(A2)
    # 1 loops, best of 3: 1.36 s per loop
    

    I wouldn't recommend trying that using broadcasting...
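
    To put a number on that: the intermediate array A[np.newaxis,:,:] - A[:,np.newaxis,:] has shape (N, N, M), so for A2 the broadcasting approach would have to materialize 1000 * 1000 * 2000 float64 values before it could even square and sum anything. A quick back-of-the-envelope check:

    N, M = 1000, 2000
    # the pairwise row differences form an (N, N, M) array of 8-byte floats
    print(N * N * M * 8 / 1e9)  # 16.0 -- about 16 GB for the intermediate alone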

    Another thing we could do is parallelize this over the outermost loop, using the prange function:

    from cython.parallel cimport prange
    
    ...
    
    for ii in prange(nrow, nogil=True, schedule='guided'):
    ...
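
    For reference, here's what a complete dist_parallel might look like (a sketch filling in the ellipses above, not necessarily the author's exact code). Two details matter: the output is written through a typed memoryview, so the loop body contains no Python API calls and can run with the GIL released, and the accumulator is updated with a plain assignment (tmpss = tmpss + ...) rather than +=, because inside a prange block Cython treats in-place updates as reduction variables and won't let you also assign to them:

    import numpy as np
    cimport numpy as np
    cimport cython
    from cython.parallel cimport prange
    from libc.math cimport sqrt

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def dist_parallel(double [:, :] A):

        cdef:
            Py_ssize_t nrow = A.shape[0]
            Py_ssize_t ncol = A.shape[1]
            Py_ssize_t ii, jj, kk

            # a typed memoryview over the output array; indexing it needs no
            # Python API, so the loop body can execute without the GIL
            double [:, :] D = np.zeros((nrow, nrow), np.double)

            double tmpss, diff

        # nogil=True releases the GIL so the threads actually run in parallel;
        # 'guided' scheduling helps balance the triangular workload, since
        # early rows do more inner-loop work than later ones
        for ii in prange(nrow, nogil=True, schedule='guided'):
            for jj in range(ii + 1, nrow):
                tmpss = 0
                for kk in range(ncol):
                    diff = A[ii, kk] - A[jj, kk]
                    tmpss = tmpss + diff * diff  # plain assignment, not +=
                tmpss = sqrt(tmpss)
                D[ii, jj] = tmpss
                D[jj, ii] = tmpss

        return np.asarray(D)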
    

    In order to compile the parallel version you'll need to tell the compiler to enable OpenMP. I haven't figured out how to do this using pyximport, but if you're using gcc you could compile it manually like this:

    $ cython fastdist.pyx
    $ gcc -shared -pthread -fPIC -fwrapv -fopenmp -O3 \
       -Wall -fno-strict-aliasing  -I/usr/include/python2.7 -o fastdist.so fastdist.c
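
    Alternatively, a small setup.py can pass the OpenMP flags through to the compiler (a sketch assuming gcc-style flags; other compilers spell them differently, e.g. MSVC uses /openmp):

    # setup.py -- build fastdist.pyx with OpenMP enabled
    from setuptools import Extension, setup
    from Cython.Build import cythonize
    import numpy as np

    ext = Extension(
        "fastdist",
        ["fastdist.pyx"],
        include_dirs=[np.get_include()],        # headers for 'cimport numpy'
        extra_compile_args=["-fopenmp", "-O3"],
        extra_link_args=["-fopenmp"],
    )

    setup(ext_modules=cythonize(ext))

    Build it in place with python setup.py build_ext --inplace.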
    

    With parallelism, using 8 threads:

    %timeit D2 = fastdist.dist_parallel(A2)
    # 1 loops, best of 3: 509 ms per loop
    
