Nearest neighbour search with a KD tree

礼貌的吻别 2021-01-20 17:51

Given a list of N points [(x_1,y_1), (x_2,y_2), ...], I am trying to find the nearest neighbour of each point based on distance. My dataset is too la
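For reference, the direct way to solve this compares every pair of points. That costs O(N²) distance evaluations, which is the cost a KD tree is meant to avoid, but it is handy for checking a tree-based result on a small sample. This is a minimal sketch assuming Euclidean distance; `brute_force_nearest` is a hypothetical name, and a point is never counted as its own neighbour.

```python
import math
from typing import List, Sequence


def brute_force_nearest(points: List[Sequence[float]]) -> List[int]:
    """For each point, return the index of its nearest *other* point.

    O(N^2) reference implementation (Euclidean distance, Python 3.8+ for math.dist).
    Returns -1 for a point that has no other point to compare against.
    """
    result = []
    for i, p in enumerate(points):
        best_j, best_d = -1, math.inf
        for j, q in enumerate(points):
            if i == j:
                continue  # skip the point itself
            d = math.dist(p, q)
            if d < best_d:
                best_j, best_d = j, d
        result.append(best_j)
    return result


print(brute_force_nearest([(2, 2), (4, 8), (10, 2)]))  # → [1, 0, 0]
```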

4 answers
  •  Happy的楠姐
    2021-01-20 18:34

    scikit-learn is probably the best option. I wrote the code below some time back, when I needed a custom distance function (I believe scikit-learn's KD tree does not support a custom distance metric). Adding it for reference.
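Custom distances matter for a KD tree because of how the search skips work. Let $q$ be the query point, $a$ the axis a node splits on, $s_a$ that node's coordinate on axis $a$, and $d_{\text{best}}$ the closest distance found so far. The subtree on the far side of the split is skipped only when the bound below guarantees it cannot contain anything closer. The first inequality holds for Minkowski distances (Euclidean, Manhattan, Chebyshev), each of which is at least the absolute difference in any single coordinate, but it is not guaranteed for an arbitrary user-defined distance function.

```latex
% For every point p on the far side of the split
% (p_a \ge s_a > q_a, or symmetrically p_a \le s_a \le q_a):
d(q, p) \;\ge\; |q_a - p_a| \;\ge\; |q_a - s_a|
% Hence, if |q_a - s_a| \ge d_{\text{best}}, no point in the far subtree
% can be closer than d_{\text{best}}, and that subtree can be skipped.
```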

    Adapted from my gist for 2D https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8

    # From https://gist.github.com/alexcpn/1f187f2114976e748f4d3ad38dea17e8
    # Author alex punnen
    from collections import namedtuple
    from operator import itemgetter
    import numpy as np
    
    def find_nearest_neighbour(node, point, distance_fn, current_axis):
        # Algorithm to find the nearest neighbour in a KD tree. The tree was built by sorting the
        # co-ordinates on alternating axes and splitting at the median: at a node that splits on x,
        # the left subtree holds points whose x is <= the node's x and the right subtree holds points
        # whose x is >= the node's x (likewise for y on the next level).
        # To search, first descend into the subtree on the same side of the split as the query point.
        # Then backtrack: the other subtree can only contain a closer point if the splitting line is
        # nearer to the query than the closest distance found so far; otherwise it is skipped.
        # This pruning assumes distance_fn is never smaller than the gap along a single axis
        # (true for Euclidean, Manhattan and Chebyshev distance, not for an arbitrary function).
        # param: node: the current node
        # param: point: the point whose nearest neighbour is to be found
        # param: distance_fn: distance function, called as distance_fn(x1, y1, x2, y2)
        # param: current_axis: only two dimensions are assumed, so this is 0 (x) or 1 (y)

        if node is None:
            return None, None
        current_closest_node = node
        closest_known_distance = distance_fn(node.cell[0], node.cell[1], point[0], point[1])

        next_axis = (current_axis + 1) % 2
        # Visit the side of the split that contains the query point first
        if node.cell[current_axis] > point[current_axis]:
            near_branch, far_branch = node.left_branch, node.right_branch
        else:
            near_branch, far_branch = node.right_branch, node.left_branch

        new_node, new_closest_distance = find_nearest_neighbour(near_branch, point, distance_fn, next_axis)
        # Compare with None explicitly: a distance of 0 (exact match) is falsy
        if new_closest_distance is not None and new_closest_distance < closest_known_distance:
            closest_known_distance = new_closest_distance
            current_closest_node = new_node

        # Backtrack into the far side only if the splitting line is closer than the best match so far
        if abs(node.cell[current_axis] - point[current_axis]) < closest_known_distance:
            new_node, new_closest_distance = find_nearest_neighbour(far_branch, point, distance_fn, next_axis)
            if new_closest_distance is not None and new_closest_distance < closest_known_distance:
                closest_known_distance = new_closest_distance
                current_closest_node = new_node

        return current_closest_node, closest_known_distance
    
    
    class Node(namedtuple('Node','cell, left_branch, right_branch')):
        # This Class is taken from wikipedia code snippet for  KD tree
        pass
    
    def create_kdtree(cell_list,current_axis,no_of_axis):
        # Creates a KD tree recursively, following the Wikipedia KD-tree snippet,
        # but cycling through any number of axes and using a slightly different data structure
        if not cell_list:
            return None
        # Convert each cell to a tuple; this line assumes 2 dimensions
        k = [(cell[0], cell[1]) for cell in cell_list]
        # For three dimensions it would be:
        # k = [(cell[0], cell[1], cell[2]) for cell in cell_list]
        k.sort(key=itemgetter(current_axis)) # sort on the current axis
        median = len(k) // 2 # get the median of the list
        axis = (current_axis + 1) % no_of_axis # cycle the axis
        return Node(k[median], # recurse 
                    create_kdtree(k[:median],axis,no_of_axis),
                    create_kdtree(k[median+1:],axis,no_of_axis))
    
    def eucleaden_dist(x1,y1,x2,y2):
        a= np.array([x1,y1])
        b= np.array([x2,y2])
        dist = np.linalg.norm(a-b)
        return dist
    
    
    np.random.seed(0)
    # Alternatively, use random points:
    # cell_list = np.random.random((2, 2)).tolist()
    cell_list = [[2, 2], [4, 8], [10, 2]]
    print(cell_list)
    tree = create_kdtree(cell_list, 0, 2)

    node, distance = find_nearest_neighbour(tree, (1, 1), eucleaden_dist, 0)
    print('Nearest Neighbour =', node.cell, distance)

    node, distance = find_nearest_neighbour(tree, (8, 1), eucleaden_dist, 0)
    print('Nearest Neighbour =', node.cell, distance)
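The function above returns the stored point closest to an arbitrary query. For the question as asked (the nearest neighbour of every point in the list), querying with a point that is already in the tree will typically return that point itself at distance 0. A minimal sketch in plain Python (no numpy) that builds the tree over point *indices* and skips the query's own index; `all_nearest_neighbours`, `build` and `_search` are hypothetical names, it works in any number of dimensions, and it assumes Euclidean distance so that the pruning bound holds.

```python
import math
from typing import List, NamedTuple, Optional, Sequence, Tuple


class KDNode(NamedTuple):
    index: int                  # index of the splitting point in the input list
    axis: int                   # coordinate this node splits on
    left: Optional["KDNode"]    # points with coordinate <= split value
    right: Optional["KDNode"]   # points with coordinate >= split value


def build(points: Sequence[Sequence[float]], indices: List[int], depth: int = 0) -> Optional[KDNode]:
    """Median-split KD tree over point indices, cycling through the axes."""
    if not indices:
        return None
    axis = depth % len(points[0])
    indices = sorted(indices, key=lambda i: points[i][axis])
    mid = len(indices) // 2
    return KDNode(indices[mid], axis,
                  build(points, indices[:mid], depth + 1),
                  build(points, indices[mid + 1:], depth + 1))


def _search(node, points, qi, best_i, best_d2) -> Tuple[int, float]:
    """Nearest point to points[qi], excluding index qi; works with squared distances."""
    if node is None:
        return best_i, best_d2
    q, p = points[qi], points[node.index]
    if node.index != qi:
        d2 = sum((a - b) ** 2 for a, b in zip(q, p))
        if d2 < best_d2:
            best_i, best_d2 = node.index, d2
    diff = q[node.axis] - p[node.axis]
    # Search the side of the split containing the query first
    near, far = (node.left, node.right) if diff < 0 else (node.right, node.left)
    best_i, best_d2 = _search(near, points, qi, best_i, best_d2)
    # Far side can only help if the splitting plane is closer than the best match
    if diff * diff < best_d2:
        best_i, best_d2 = _search(far, points, qi, best_i, best_d2)
    return best_i, best_d2


def all_nearest_neighbours(points: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """For each point, return (index of nearest other point, Euclidean distance).

    A point with no other point to compare against gets (-1, inf).
    """
    root = build(points, list(range(len(points))))
    result = []
    for i in range(len(points)):
        j, d2 = _search(root, points, i, -1, math.inf)
        result.append((j, math.sqrt(d2)))
    return result


print(all_nearest_neighbours([(2, 2), (4, 8), (10, 2)]))
```

Each query costs roughly O(log N) on well-spread data, so all N queries together are far cheaper than the pairwise comparison; duplicate points are reported as each other's neighbour at distance 0.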
    
