Calculate Distances Between One Point in Matrix From All Other Points

借酒劲吻你 2021-02-15 16:50

I am new to Python and I need to implement a clustering algorithm. For that, I will need to calculate distances between the given input data.

Consider the following inpu

4 Answers
  •  没有蜡笔的小新
    2021-02-15 16:54

    Here's one approach using SciPy's cdist -

    from scipy.spatial.distance import cdist
    def closest_rows(a):
        # Get squared Euclidean distances as a 2D array (squaring preserves
        # the ordering, so argmin matches that of the true distances)
        dists = cdist(a, a, 'sqeuclidean')
    
        # Fill diagonals with something greater than all elements as we intend
        # to get argmin indices later on and then index into input array with those
        # indices to get the closest rows
        dists.ravel()[::dists.shape[1]+1] = dists.max()+1
        return a[dists.argmin(1)]
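
    The strided assignment used above to overwrite the diagonal can be checked in isolation. The sketch below is an illustration rather than part of the original solution: it shows that stepping through the flattened array by `ncols + 1` lands on the diagonal, and that the trick relies on `ravel()` returning a view, which holds for the C-contiguous output of cdist but not for every array. `np.fill_diagonal` is shown as an alternative that does not depend on the memory layout.

```python
import numpy as np

# C-contiguous array: ravel() returns a view, so the strided write
# reaches the original array and overwrites its diagonal.
d = np.arange(16.0).reshape(4, 4)
d.ravel()[::d.shape[1] + 1] = -1.0
print(np.diag(d))          # diagonal is now all -1

# Transposed (non C-contiguous) array: ravel() has to copy, so the
# same assignment silently leaves the original untouched.
t = np.arange(16.0).reshape(4, 4).T
t.ravel()[::t.shape[1] + 1] = -1.0
print(np.diag(t))          # diagonal unchanged

# np.fill_diagonal writes in place regardless of memory layout.
np.fill_diagonal(t, -1.0)
print(np.diag(t))          # diagonal is now all -1
```

    Filling with `np.inf` instead of `dists.max()+1` would serve the same purpose here, since the only requirement is that a point never wins the argmin against itself.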
    

    Sample run -

    In [72]: a
    Out[72]: 
    array([[1, 2, 8],
           [7, 4, 2],
           [9, 1, 7],
           [0, 1, 5],
           [6, 4, 3]])
    
    In [73]: closest_rows(a)
    Out[73]: 
    array([[0, 1, 5],
           [6, 4, 3],
           [6, 4, 3],
           [1, 2, 8],
           [7, 4, 2]])
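
    To make explicit what the vectorized function computes, here is a plain-Python reference that loops over every pair of rows. It is only a sketch for cross-checking small inputs (the name `closest_rows_reference` is made up for this illustration), and it is far too slow for the 10,000-point case timed later in this answer.

```python
def closest_rows_reference(points):
    """For each row, return the other row with the smallest squared distance."""
    result = []
    for i, p in enumerate(points):
        best_j, best_d = None, None
        for j, q in enumerate(points):
            if i == j:
                continue  # a point is never its own nearest neighbour
            d = sum((x - y) ** 2 for x, y in zip(p, q))
            # Strict '<' keeps the first minimum, like argmin does
            if best_d is None or d < best_d:
                best_j, best_d = j, d
        result.append(points[best_j])
    return result


sample = [[1, 2, 8], [7, 4, 2], [9, 1, 7], [0, 1, 5], [6, 4, 3]]
print(closest_rows_reference(sample))
# [[0, 1, 5], [6, 4, 3], [6, 4, 3], [1, 2, 8], [7, 4, 2]]
```

    The result agrees row for row with the sample run above.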
    

    Runtime test

    Other working approach(es) -

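    import numpy as np

    def norm_app(a): # @Psidom's soln
        # Pairwise Euclidean distances via broadcasting
        dist = np.linalg.norm(a - a[:,None], axis=-1)
        # Mask self-distances with NaN so that nanargmin skips them
        dist[np.arange(dist.shape[0]), np.arange(dist.shape[0])] = np.nan
        return a[np.nanargmin(dist, axis=0)]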
    

    Timings with 10,000 points -

    In [79]: a = np.random.randint(0,9,(10000,3))
    
    In [80]: %timeit norm_app(a) # @Psidom's soln
    1 loop, best of 3: 3.83 s per loop
    
    In [81]: %timeit closest_rows(a)
    1 loop, best of 3: 392 ms per loop
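
    One likely contributor to this gap is memory: the broadcasted subtraction `a - a[:,None]` in norm_app materializes an intermediate of shape (n, n, d) before reducing it, whereas cdist writes straight into an (n, n) result. The sketch below only illustrates the shapes and sizes involved; the 8-byte item size is an assumption (int64 input), and the helper name is made up for this example.

```python
import numpy as np

# Small example: 5 points in 3 dimensions
a = np.arange(15).reshape(5, 3)
diff = a - a[:, None]      # broadcasts to shape (5, 5, 3)
print(diff.shape)


def pairwise_diff_nbytes(n_points, n_dims, itemsize=8):
    """Bytes needed for the (n, n, d) intermediate (itemsize is assumed)."""
    return n_points * n_points * n_dims * itemsize


# For the 10,000 x 3 input timed above, assuming 8-byte integers:
print(pairwise_diff_nbytes(10_000, 3) / 1e9, "GB")   # 2.4 GB
```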
    

    Further performance boost

    There's the eucl_dist package (disclaimer: I am its author), which contains various methods for computing Euclidean distances that are much more efficient than SciPy's cdist, especially for large arrays.

    Making use of it, we get a more performant version, like so -

    from eucl_dist.cpu_dist import dist
    def closest_rows_v2(a):
        dists = dist(a,a, matmul="gemm", method="ext") 
        dists.ravel()[::dists.shape[1]+1] = dists.max()+1
        return a[dists.argmin(1)]
    

    Timings -

    In [162]: a = np.random.randint(0,9,(10000,3))
    
    In [163]: %timeit closest_rows(a)
    1 loop, best of 3: 394 ms per loop
    
    In [164]: %timeit closest_rows_v2(a)
    1 loop, best of 3: 229 ms per loop
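
    The internals of eucl_dist are not shown in this answer, so the following is not its implementation. It is a NumPy-only sketch of the general matrix-multiplication idea that the `matmul="gemm"` argument hints at, based on the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y. The function names are made up for this illustration.

```python
import numpy as np


def sq_dists_matmul(a):
    """Squared Euclidean distances between all rows of `a` via one matmul."""
    a = np.asarray(a, dtype=np.float64)
    sq_norms = np.einsum("ij,ij->i", a, a)            # ||x||^2 for each row
    d = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (a @ a.T)
    # Rounding can push tiny values slightly below zero; clamp them
    np.maximum(d, 0.0, out=d)
    return d


def closest_rows_matmul(a):
    """Nearest other row for each row, using the matmul-based distances."""
    a = np.asarray(a)
    d = sq_dists_matmul(a)
    np.fill_diagonal(d, np.inf)                       # exclude self-matches
    return a[d.argmin(axis=1)]


a = np.array([[1, 2, 8], [7, 4, 2], [9, 1, 7], [0, 1, 5], [6, 4, 3]])
print(closest_rows_matmul(a))
```

    On the sample array from the first section this selects the same rows as closest_rows. With floating-point inputs the expanded form can lose some precision compared with computing the differences directly, which is the usual trade-off of this formulation.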
    
