numpy实现KNN代码

匿名 (未验证) 提交于 2019-12-03 00:03:02

代码参考征哥博客:传送门

简介:

KNN的基本思想是根据与测试样本相邻最近的k个样本的标签,去决定该样本的预测值。KNN有三个要素:k值选择,距离度量方式和决策准则。

KNN没有显式的训练过程,计算基本都在预测阶段。

1. K值选择

一般先选取一个较小的k值,然后通过交叉验证来确定k的取值。

2. 距离度量方式

一般选择欧氏距离、曼哈顿距离或余弦相似度。

3. 决策准则

一般分类用多数表决法,回归用平均法。

 

一般情况下直接遍历一遍样本,计算测试样本与训练集中每个样本的距离,然后选择最近的k个。但是这样在样本集非常大时效率不高,优化方案是使用KD树或者球树来寻找k近邻,具体细节可以参考:传送门

 

代码:

import numpy as np from collections import Counter   class KNN:     def __init__(self, task_type='classification'):         self.train_data = None         self.train_label = None         self.task_type = task_type      def fit(self, train_data, train_label):         self.train_data = np.array(train_data)         self.train_label = np.array(train_label)      def predict(self, test_data, k=3, distance='l2'):         test_data = np.array(test_data)         preds = []         for x in test_data:             if distance == 'l1':                 dists = self.l1_distance(x)             elif distance == 'l2':                 dists = self.l2_distance(x)             else:                 raise ValueError('wrong distance type')             sorted_idx = np.argsort(dists)             knearnest_labels = self.train_label[sorted_idx[:k]]             pred = None             if self.task_type == 'regression':                 pred = np.mean(knearnest_labels)             elif self.task_type == 'classification':                 pred = Counter(knearnest_labels).most_common(1)[0][0]             preds.append(pred)         return preds      def l1_distance(self, x):         return np.sum(np.abs(self.train_data-x), axis=1)      def l2_distance(self, x):         return np.sum(np.square(self.train_data-x), axis=1)   if __name__ == '__main__':     train_data = [[1, 1, 1], [2, 2, 2], [10, 10, 10], [13, 13, 13]]     # train_label = ['aa', 'aa', 'bb', 'bb']     train_label = [1, 2, 30, 60]     test_data = [[3, 2, 4], [9, 13, 11], [10, 20, 10]]     knn = KNN(task_type='regression')     knn.fit(train_data, train_label)     preds = knn.predict(test_data, k=2)     print(preds)

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!