What is the most efficient way to compute the squared Euclidean distance between N samples and cluster centroids?

淺唱寂寞╮ submitted on 2019-11-29 15:48:18

This question is frequently asked in combination with nearest-neighbour search. If that is your case, have a look at a k-d tree approach. It will be far more efficient than calculating all the Euclidean distances, both in memory consumption and in performance.
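
A minimal sketch of the k-d tree route, assuming SciPy is available and using `scipy.spatial.cKDTree`. The array names and sizes are illustrative and not part of the original question. The tree is built over the centroids, and each sample is then queried for its single nearest centroid, so the full N x K distance matrix is never formed.

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
samples = rng.standard_normal((1000, 3))   # N samples (illustrative data)
centroids = rng.standard_normal((8, 3))    # K cluster centroids (illustrative data)

# Build the tree over the centroids, then look up the nearest one per sample.
tree = cKDTree(centroids)
dist, nearest = tree.query(samples, k=1)   # Euclidean distance and centroid index

sq_dist = dist ** 2                        # squared distance to the nearest centroid
```

Keep in mind that k-d trees tend to lose their advantage as the number of dimensions grows, so it is worth timing both routes on your own data.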

If this is not the case, here are some possible approaches. The first two are from an answer by Divakar. The third one uses Numba for JIT compilation. Its main difference from the first two versions is the avoidance of temporary arrays: the result is written in place into the array that holds the dot products.
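
All of these approaches rely on the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which turns the bulk of the work into one matrix product. As a rough illustration of why that matters, the sketch below compares it with naive broadcasting, which materialises an (N, M, D) temporary. The function names `naive_sq_dist` and `expanded_sq_dist` are hypothetical and used only for this example.

```python
import numpy as np

def naive_sq_dist(A, B):
    # Broadcasts to an (N, M, D) temporary before reducing over the last axis.
    diff = A[:, None, :] - B[None, :, :]
    return (diff * diff).sum(axis=-1)

def expanded_sq_dist(A, B):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; the largest temporary is (N, M).
    a2 = (A * A).sum(axis=1)[:, None]
    b2 = (B * B).sum(axis=1)[None, :]
    return a2 + b2 - 2.0 * (A @ B.T)

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 3))
B = rng.standard_normal((150, 3))

# Both formulations should agree up to floating-point rounding.
print(np.allclose(naive_sq_dist(A, B), expanded_sq_dist(A, B)))
```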

Three approaches to calculate Euclidean distances

import numpy as np
import numba as nb

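# @Paul Panzer
# https://stackoverflow.com/a/42994680/4045774
# Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; returns an (N, M) array of squared distances.
def outer_sum_dot_app(A, B):
    return np.add.outer((A*A).sum(axis=-1), (B*B).sum(axis=-1)) - 2*np.dot(A, B.T)

# @Divakar
# https://stackoverflow.com/a/42994680/4045774
# Same expansion, with the row-wise squared norms computed by einsum.
def outer_einsum_dot_app(A, B):
    return np.einsum('ij,ij->i', A, A)[:, None] + np.einsum('ij,ij->i', B, B) - 2*np.dot(A, B.T)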

@nb.njit()
def calc_dist(A, B, sqrt=False):
    # Start from the (N, M) matrix of dot products and overwrite it in place.
    dist = np.dot(A, B.T)

    # Squared norm of every row of A.
    TMP_A = np.empty(A.shape[0], dtype=A.dtype)
    for i in range(A.shape[0]):
        acc = 0.
        for j in range(A.shape[1]):
            acc += A[i, j]**2
        TMP_A[i] = acc

    # Squared norm of every row of B.
    TMP_B = np.empty(B.shape[0], dtype=A.dtype)
    for i in range(B.shape[0]):
        acc = 0.
        for j in range(B.shape[1]):
            acc += B[i, j]**2
        TMP_B[i] = acc

    if sqrt:
        for i in range(A.shape[0]):
            for j in range(B.shape[0]):
                # Clamp at zero: rounding can make the expansion slightly negative.
                dist[i, j] = np.sqrt(max(-2.*dist[i, j] + TMP_A[i] + TMP_B[j], 0.))
    else:
        for i in range(A.shape[0]):
            for j in range(B.shape[0]):
                dist[i, j] = -2.*dist[i, j] + TMP_A[i] + TMP_B[j]
    return dist
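
For the use case in the title, the squared-distance matrix is typically reduced to one label per sample. The following is a small sketch under assumed shapes (500 samples, 5 centroids, 4 features), with a hypothetical helper `sq_dist` that repeats the einsum formulation shown above so the example runs without Numba.

```python
import numpy as np

def sq_dist(A, B):
    # Same expansion as outer_einsum_dot_app above.
    d = np.einsum('ij,ij->i', A, A)[:, None] + np.einsum('ij,ij->i', B, B) - 2 * np.dot(A, B.T)
    # Rounding can leave tiny negative entries; clamp before any square root.
    return np.maximum(d, 0.0)

rng = np.random.default_rng(0)
samples = rng.standard_normal((500, 4))    # illustrative data
centroids = rng.standard_normal((5, 4))    # illustrative data

d2 = sq_dist(samples, centroids)           # shape (500, 5)
labels = d2.argmin(axis=1)                 # index of the closest centroid per sample
```

Since the square root is monotonic, taking `argmin` on the squared distances gives the same labels as on the true distances, so the root can be skipped when only the assignment is needed.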

Timings

A = np.random.randn(10000,3)
B = np.random.randn(10000,3)

# calc_dist:                        360 ms (first call excluded due to compilation overhead)
# outer_einsum_dot_app (Divakar):  1150 ms
# outer_sum_dot_app (Paul Panzer): 1590 ms
# dist_2:                          1840 ms

A = np.random.randn(1000,100)
B = np.random.randn(1000,100)

# calc_dist:                        4.3 ms (first call excluded due to compilation overhead)
# outer_einsum_dot_app (Divakar):  13.12 ms
# outer_sum_dot_app (Paul Panzer): 13.2 ms
# dist_2:                          21.3 ms
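
The numbers above depend on the machine and library versions, so here is one possible way to reproduce such a measurement with the standard `timeit` module. This is only a sketch: the repeat counts are arbitrary choices, and only the einsum variant is timed so that the example runs without Numba.

```python
import timeit
import numpy as np

def outer_einsum_dot_app(A, B):
    return np.einsum('ij,ij->i', A, A)[:, None] + np.einsum('ij,ij->i', B, B) - 2 * np.dot(A, B.T)

rng = np.random.default_rng(0)
A = rng.standard_normal((1000, 100))
B = rng.standard_normal((1000, 100))

# Warm-up call; for a Numba function this would also trigger compilation.
outer_einsum_dot_app(A, B)

# Best of 5 repeats, 10 calls each, reported per call in milliseconds.
best = min(timeit.repeat(lambda: outer_einsum_dot_app(A, B), number=10, repeat=5)) / 10
print(f"outer_einsum_dot_app: {best * 1e3:.2f} ms")
```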