Fast weighted euclidean distance between points in arrays

清歌不尽 2020-12-19 14:54

I need to efficiently calculate the weighted Euclidean distance from every x,y point in one array to every x,y point in another array.

4 answers
  •  盖世英雄少女心
    2020-12-19 15:08

    @Evan and @Martinis Group are on the right track. To expand on Evan's answer, here's a function that uses broadcasting to calculate the n-dimensional weighted Euclidean distance quickly, without Python loops:
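
    Written out in index notation (a restatement of the docstring below, where d runs over the k coordinate dimensions and each weight belongs to a point of A), the quantity being computed is:

```latex
D_{ij} = \sqrt{\sum_{d=0}^{k-1} \left( \frac{A_{d,i} - B_{d,j}}{W_{d,i}} \right)^{2}},
\qquad i = 0,\dots,m-1,\quad j = 0,\dots,n-1
```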

    import numpy as np
    
    def fast_wdist(A, B, W):
        """
        Compute the weighted euclidean distance between two arrays of points:
    
        D{i,j} = 
        sqrt( ((A{0,i}-B{0,j})/W{0,i})^2 + ... + ((A{k,i}-B{k,j})/W{k,i})^2 )
    
        inputs:
            A is a (k, m) array of coordinates
            B is a (k, n) array of coordinates
            W is a (k, m) array of weights
    
        returns:
            D is an (m, n) array of weighted euclidean distances
        """
    
        # compute the differences and apply the weights in one go using
        # broadcasting jujitsu. the result is (n, k, m)
        wdiff = (A[np.newaxis,...] - B[np.newaxis,...].T) / W[np.newaxis,...]
    
        # square and sum over the second axis, take the sqrt and transpose. the
        # result is an (m, n) array of weighted euclidean distances
        D = np.sqrt((wdiff*wdiff).sum(1)).T
    
        return D
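
    To see what the broadcasting does, here is a sketch that runs the same three steps on small random inputs and prints every intermediate shape. The sizes k=2, m=4, n=3 are arbitrary, picked only so that each axis has a different length, and the final spot-check compares one entry with the formula from the docstring.

```python
import numpy as np

# Arbitrary sizes, chosen only so that every axis has a distinct length.
k, m, n = 2, 4, 3
A = np.random.randn(k, m)        # k coordinates for each of m points
B = np.random.randn(k, n)        # k coordinates for each of n points
W = np.random.rand(k, m) + 0.5   # one positive weight per coordinate of each A point

a3 = A[np.newaxis, ...]          # (1, k, m)
b3 = B[np.newaxis, ...].T        # (n, k, 1): .T reverses all three axes
w3 = W[np.newaxis, ...]          # (1, k, m)

wdiff = (a3 - b3) / w3           # broadcasts to (n, k, m)
summed = (wdiff * wdiff).sum(1)  # sum over the k axis -> (n, m)
D = np.sqrt(summed).T            # transpose -> (m, n)

print(a3.shape, b3.shape, wdiff.shape, summed.shape, D.shape)
# (1, 2, 4) (3, 2, 1) (3, 2, 4) (3, 4) (4, 3)

# Spot-check one entry against the docstring formula.
i, j = 1, 2
expected = np.sqrt((((A[:, i] - B[:, j]) / W[:, i]) ** 2).sum())
assert np.isclose(D[i, j], expected)
```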
    

    To check that this works OK, we'll compare it to a slower version that uses nested Python loops:

    def slow_wdist(A, B, W):
    
        k,m = A.shape
        _,n = B.shape
        D = np.zeros((m, n))
    
        for ii in range(m):
            for jj in range(n):
                wdiff = (A[:,ii] - B[:,jj]) / W[:,ii]
                D[ii,jj] = np.sqrt((wdiff**2).sum())
        return D
    

    First, let's make sure that the two functions give the same answer:

    # make some random points and weights
    def setup(k=2, m=100, n=300):
        return np.random.randn(k, m), np.random.randn(k, n), np.random.randn(k, m)
    
    a, b, w = setup()
    d0 = slow_wdist(a, b, w)
    d1 = fast_wdist(a, b, w)
    
    print(np.allclose(d0, d1))
    # True
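
    `np.allclose` only shows that the two implementations agree with each other. As an extra sanity check, here is a tiny hand-worked case (the input values are made up for illustration): with unit weights, the distance from the origin to (3, 4) is 5, and with weights (3, 4) each scaled difference is -1, giving sqrt(2). The function body is a condensed copy of `fast_wdist` so that the snippet runs on its own.

```python
import numpy as np

def fast_wdist(A, B, W):
    # Condensed copy of the broadcasting version above.
    wdiff = (A[np.newaxis, ...] - B[np.newaxis, ...].T) / W[np.newaxis, ...]
    return np.sqrt((wdiff * wdiff).sum(1)).T

# Two 2-D points in A (one per column), both at the origin.
A = np.array([[0.0, 0.0],
              [0.0, 0.0]])
# Point 0 is unweighted; point 1 has weights (3, 4).
W = np.array([[1.0, 3.0],
              [1.0, 4.0]])
# A single point in B at (3, 4).
B = np.array([[3.0],
              [4.0]])

D = fast_wdist(A, B, W)
# Expect D to have shape (2, 1), with D[0, 0] == 5 and D[1, 0] == sqrt(2).
print(D.shape, D[0, 0], D[1, 0])
```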
    

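    Needless to say, the version that uses broadcasting rather than Python loops is much faster. In the timings below it takes about 620 µs per call versus 647 ms for the loop version, roughly a 1000x (three orders of magnitude) speed-up: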

    %%timeit a, b, w = setup()
    slow_wdist(a, b, w)
    # 1 loops, best of 3: 647 ms per loop
    
    %%timeit a, b, w = setup()
    fast_wdist(a, b, w)
    # 1000 loops, best of 3: 620 us per loop
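
    The speed-up comes at a memory cost: `wdiff` materializes an (n, k, m) temporary, k times the size of the (m, n) result, and `wdiff*wdiff` allocates another array of the same size. If that becomes a problem, one option (a sketch, not part of the original answer; `chunked_wdist` and its `block` parameter are hypothetical names) is to loop over blocks of A's points so that the temporary never exceeds (k, block, n), and to let `np.einsum` do the square-and-sum:

```python
import numpy as np

def chunked_wdist(A, B, W, block=1024):
    """Hypothetical memory-bounded variant of fast_wdist.

    A: (k, m), B: (k, n), W: (k, m). Returns the same (m, n) array,
    but only builds a (k, block, n) difference array per iteration.
    """
    k, m = A.shape
    _, n = B.shape
    D = np.empty((m, n))
    for start in range(0, m, block):
        stop = min(start + block, m)
        a = A[:, start:stop, np.newaxis]         # (k, b, 1)
        w = W[:, start:stop, np.newaxis]         # (k, b, 1)
        wdiff = (a - B[:, np.newaxis, :]) / w    # broadcasts to (k, b, n)
        # Square and sum over the k axis in one call -> (b, n).
        D[start:stop] = np.sqrt(np.einsum('kij,kij->ij', wdiff, wdiff))
    return D

def fast_wdist(A, B, W):
    # Condensed copy of the broadcasting version above, for comparison.
    wdiff = (A[np.newaxis, ...] - B[np.newaxis, ...].T) / W[np.newaxis, ...]
    return np.sqrt((wdiff * wdiff).sum(1)).T

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 100))
b = rng.standard_normal((2, 300))
w = rng.uniform(0.5, 2.0, (2, 100))   # positive weights, made up for the test
print(np.allclose(chunked_wdist(a, b, w, block=32), fast_wdist(a, b, w)))  # True
```

    Larger `block` values mean fewer Python-level iterations but more memory per iteration; a suitable size depends on k, n and the available RAM.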
    
