How to find “nearest neighbors” in a list in Python?

Submitted by 耗尽温柔 on 2021-02-07 20:12:53

Question


My objective is the following: given a 1D list of numbers (int or float), I need to find the N numbers that are closest to each other in this list. For example:

Given sample_list = [1, 8, 14, 15, 17, 100, 35] and N = 3, the desired function should return a list of [14, 15, 17], because the 3 "nearest neighbors" in this list are 14, 15, and 17.

So far, I have tackled the problem the following way, being forced to make several compromises in my approach (after unsuccessfully searching for a solution online):

I decided to iterate over each number in the list and find its N "nearest neighbors". These neighbors are stored in a new list, which is then compared with the other such lists using the numpy.diff function. The list with the smallest sum of numpy.diff values is the winner.
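As a side note on the scoring step: for a sorted list, the consecutive differences returned by numpy.diff telescope, so their sum is simply the largest value minus the smallest. A minimal sketch of that idea (the variable names here are only illustrative):

```python
import numpy as np

candidate = sorted([17, 14, 15])   # one candidate group, sorted ascending
gaps = np.diff(candidate)          # gaps between consecutive values
score = int(gaps.sum())            # telescopes to max(candidate) - min(candidate)

print(gaps.tolist(), score)        # [1, 2] 3
```

A smaller score therefore means the group spans a narrower range.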

At the base of my solution is the code provided by this article on geeksforgeeks.org. The idea to use numpy.diff function came from here on stackoverflow.com

My (working) monstrosity can be found below:

import numpy as np

def find_nn_in_list_1D(list_for_search, num_neighbors_desired):
    neighbor_eval_score = 999_999_999

    for number in list_for_search:
        # Start each pass from a fresh copy, since picked values are removed.
        list_mod = list_for_search.copy()
        neighbor_storage_list = []

        # The value closest to `number` is `number` itself, so it is picked first.
        neighbor_nil = list_mod[min(range(len(list_mod)), key=lambda i: abs(list_mod[i] - number))]
        neighbor_storage_list.append(neighbor_nil)
        list_mod.remove(neighbor_nil)

        # Then pick num_neighbors_desired further values, nearest first.
        for num_neighbor in range(num_neighbors_desired):
            neighbor = list_mod[min(range(len(list_mod)), key=lambda i: abs(list_mod[i] - number))]
            neighbor_storage_list.append(neighbor)
            list_mod.remove(neighbor)

        neighbor_storage_list.sort()

        # Score the group by the sum of the gaps between its sorted values.
        score = abs(sum(np.diff(neighbor_storage_list)))
        if neighbor_eval_score > score:
            neighbor_eval_score = score
            nearest_neighbors_list = neighbor_storage_list.copy()

    return nearest_neighbors_list

I am very certain that there is a much better way of solving this issue.

Even though it works, I would like to write better, cleaner code.

If you have a better solution than mine please share below. Thank you for your time!


Answer 1:


This will work -

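import numpy as np

N = 3
# Use a float dtype so the diagonal can be filled with np.inf below
X = np.array([1, 8, 14, 15, 17, 100, 35]).astype(np.float32)

# Pairwise absolute differences: d[i, j] = |X[j] - X[i]|
d = np.abs(X[None, :] - X[:, None])
# Mask the zero self-distances on the diagonal
np.fill_diagonal(d, np.inf)

# Distance from each element to its nearest neighbor
min_array = d.min(0)

# Indices of the N smallest nearest-neighbor distances (in no particular order)
par = np.argpartition(min_array, N)

# Sort the selected values so the printed order is predictable
print(np.sort(X[par[:N]]))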

This gives -

[14. 15. 17.]

First, build a matrix of absolute differences between every pair of elements. Mask out the diagonal of that matrix, because those entries are always zero (each one is the distance from a point to itself). Then take each element's smallest distance to any other element, and use the N smallest of those values to find the indices of the elements you are looking for.
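To make the intermediate values visible, here is a minimal sketch of the same steps wrapped in a function. The name nearest_by_min_gap and the choice to return the per-element gaps are assumptions made for illustration, not part of the answer above.

```python
import numpy as np

def nearest_by_min_gap(values, n):
    """Hypothetical helper: rank elements by distance to their nearest neighbor."""
    x = np.asarray(values, dtype=float)      # float so the diagonal can hold inf
    d = np.abs(x[None, :] - x[:, None])      # d[i, j] = |x[j] - x[i]|
    np.fill_diagonal(d, np.inf)              # ignore each element's distance to itself
    nearest_gap = d.min(axis=0)              # smallest distance to any other element
    idx = np.argpartition(nearest_gap, n - 1)[:n]  # n smallest gaps, unordered
    return np.sort(x[idx]), nearest_gap

picked, gaps = nearest_by_min_gap([1, 8, 14, 15, 17, 100, 35], 3)
print(gaps.tolist())    # [7.0, 6.0, 1.0, 1.0, 2.0, 65.0, 18.0]
print(picked.tolist())  # [14.0, 15.0, 17.0]
```

Note that this criterion ranks elements individually by their nearest-neighbor gap. On inputs with several separate tight pairs, such as [1, 2, 10, 11, 20, 21], the selected elements may come from different clusters, so it is worth checking against your own data.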




Answer 2:


If there is no particular dependency on numpy, we could solve this with some comprehensions as well:

lst = [1, 8, 14, 15, 17, 100, 35]

# Make pairs and remove duplicates using a set
lst_pairs = {tuple(sorted((x, y))) for x in lst for y in lst if x != y}

# Make a dictionary mapping each pair to its distance
dist_dict = {k: k[1] - k[0] for k in lst_pairs}

# Extract the 3 closest pairs by distance
nearest_n = sorted(dist_dict, key=lambda x: dist_dict[x])[:3]

# Unpack the pairs to unique points
nearest_n_pts = {pt for pair in nearest_n for pt in pair}

# nearest_n_pts is the set {14, 15, 17}
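Another option, different from both answers above, is to sort the list and slide a window of N consecutive values over it. This sketch assumes "closest to each other" means the group with the smallest range (max minus min), which matches the numpy.diff score used in the question; the function name closest_group is made up for illustration.

```python
def closest_group(values, n):
    """Return the n values spanning the smallest range (max - min), sorted."""
    if not 1 <= n <= len(values):
        raise ValueError("n must be between 1 and len(values)")
    ordered = sorted(values)
    # After sorting, the tightest group of n values is always a run of
    # n consecutive items, so only len(ordered) - n + 1 windows need checking.
    start = min(
        range(len(ordered) - n + 1),
        key=lambda i: ordered[i + n - 1] - ordered[i],
    )
    return ordered[start:start + n]

print(closest_group([1, 8, 14, 15, 17, 100, 35], 3))  # [14, 15, 17]
```

Because the work is dominated by the sort, this runs in O(n log n) time and avoids building all pairs.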


Source: https://stackoverflow.com/questions/66005832/how-to-find-nearest-neighbors-in-a-list-in-python
