K最近邻(k-Nearest Neighbor,KNN)

独自空忆成欢 提交于 2019-11-26 17:13:20

前言:

            学习是就像卷心菜,一层一层的逐步接近菜心,但你永远拨不到最里面那层,无限接近,之前写了一篇关于KNN的算法,使用Python写的,感觉最近又有了新的理解,理解又深刻了一点,so

介绍算法

认识

            编程常常说学会三大语句结构,走遍世界全不怕,想一想人做的操作基本上也都是 分三步,识别(感知),判断(决定),操作(行动),比如你写一个字,首先要识别到纸和笔,然后决定在哪里写,然后开始写,但是3者又是嵌套关系,在你识别纸和笔的时候,你又做了这三个步骤,1 识别看到  2 判断是不是纸 3 拿过来放到桌子上 还是不动。

           那我们常会有这个东西是什么的问题,这个是笔吗?是我就拿过来用,不是我就在找,那识别这个东西是什么?其实根本是一个分类问题,如果这个是笔,那么我之前用过笔,知道笔能写字,如果是铅笔,有写一段时间需要转笔刀或刀,如果是自动笔,一段时间需要加笔芯,而且要轻点写容易断,如果是油笔需要换笔芯,分类后我们能借助之前这个类型物品的经验,进行新的判断操作。

           KNN算法就是解决分类问题的算法,是机器学习算法的一种,k-Nearest Neighbor  意思就是离着k最近的邻居。

如何分类

1 简单理解

           如图,这张图很经典,有蓝色的方块和红色的三角,这代表2类型的数据,现在放入一个未知类型的数据绿色的圆圈,那他是三角还是方块?如果k=3    那就是找到离着绿色最近的3个数据,如果三角的多,就判定他是三角类,如果方块多就判断他是方块,如图k=5的时候有3个方块,和2个三角,所以k值对分类很有影响。

简单了解后你肯定会有疑问,难道我把绿色的圆随便放一个地方就分类了?当然不是,继续看下面的例子。

2分类电影

电影名称

打斗镜头

接吻镜头

电影类型

电影1

10

101

爱情片

电影2

7

89

爱情片

电影3

108

5

动作片

电影4

115

8

动作片

如果现在有一个电影5 打斗镜头12接吻镜头120   按照knn算法 如果 k=3这个电影是什么分类?

         是爱情片,因为离着电影5 最近的3个电影是电影4和电影3 还有电影1,其中2个是爱情片,所以我们把这个未知的电影划到爱情片的分类。

 

3 如何判断距离?

电影5离着哪个电影近肯定是算出来的,在平面直角坐标系中,距离计算是

 

 

4 为什么距离的越近就是一类的

          每个点都是一个实体,每个实体有多个属性,多个属性,根据多个属性算两个点距离,如果越近,说明2个点约相似,算距离的公式有很多种,文章下面会用到欧氏距离。

算法原理

          其实KNN算法的原理很简单,就是找K个临近的,临近的最多是什么类型的,就判断当前实体是什么类型的,一般K都是奇数 ,如果是偶数  2个A类型 2个B类型就没法判断了对吧。

          那么问题来了,如果画图,我们用眼睛看很容易就判别出离得最近的K个,但是实际上  数据不止2个维度,比如电影还有喜剧片,也可以加笑的画面来判断是否是喜剧片,那时候就是3维了,不能画在平面上,现实是实体会有很多维度,这时候多维度算距离有专门的距离公式

思考实现

思考1

            1 把所有的实体放入数组中

            2 遍历数组 每个元素计算要新加元素的距离

            3 按照距离排序

            4 取出离着最近的K个元素

            想法:这样虽然能实现,但是每个元素都要循环一次算距离消耗实在太大了

思考2

            其实也没啥思考看书,书上用kd树实现的

            1 构建kd树

            2 使用kd树找出最近的K个

如何构建kd树

        首先kd树是一颗树,就是二叉树那种树,嗯0 0,因为数据一般有多个维度,这里拿2个维度举例子,要想看2维 要先看1维度   1  2 3  4 5 6 7构建树

             第一步 找中位数

             第二部 按照中位数划分  小的在左边 大的在右边

             第三步 根据分成2组  123  和567 分别重新执行第一步,直到不可再分

一维是不是很简单  那么看  2维度

             数据集{A(1,3),B(12.4),C(3,9),D(31,22),E(,34,11),F(,100,3),G(123,22) }

             第一步  按照x 找中位数

             列出x  A(1),B(12),C(3),D(31),E(,34),F(,100),G(123)

             排序 x  A(1),C(3),B(12),D(31),E(,34),F(,100),G(123)

             中位 划分   A(1),C(3),B(12),      D(31)         ,E(,34),F(,100),G(123)

现在树是这样的

第二步  按y 划分

             列出 y  排序

             A(3),B(4),C(9)

              F(3),E(11),G(22)

             划分之后树就变成了这样

第三步

如果还没构建完毕,就重复第一步和第二步,有多少个维度,就循环维度划分每一层树

如何构建kd树图版本

java代码 

package KD树;


public class Node implements Comparable<KD树.Node> {
    public double[] data;//树上节点的数据  是一个多维的向量
    public double distance;//与当前查询点的距离  初始化的时候是没有的
    public KD树.Node left, right, parent;//左右子节点  以及父节点
    public int dim = -1;//维度  建立树的时候判断的维度

    public Node(double[] data) {
        this.data = data;
    }

    /**
     * 返回指定索引上的数值
     *
     * @param index
     * @return
     */
    public double getData(int index) {
        if (data == null || data.length <= index)
            return Integer.MIN_VALUE;
        return data[index];
    }

    @Override
    public int compareTo(KD树.Node o) {
        if (this.distance > o.distance)
            return 1;
        else if (this.distance == o.distance)
            return 0;
        else return -1;
    }

    /**
     * 计算距离 这里返回欧式距离
     *
     * @param that
     * @return
     */
    public double computeDistance(KD树.Node that) {
        if (this.data == null || that.data == null || this.data.length != that.data.length)
            return Double.MAX_VALUE;//出问题了  距离最远
        double d = 0;
        for (int i = 0; i < this.data.length; i++) {
            d += Math.pow(this.data[i] - that.data[i], 2);
        }

        return Math.sqrt(d);
    }

    public String toString() {
        if (data == null || data.length == 0)
            return null;
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < data.length; i++)
            sb.append(data[i] + " ");
        sb.append(" d:" + this.distance);
        return sb.toString();
    }
}

package KD树;

public class BinaryTreeOrder {
    public void preOrder(Node root) {
        if (root != null) {
            System.out.print(root.toString());
            preOrder(root.left);
            preOrder(root.right);

        }
    }
}
package KD树;

import java.util.ArrayList;
import java.util.List;

public class kd_main {


    public static void main(String[] args) {
        List<Node> nodeList = new ArrayList<Node>();

        nodeList.add(new Node(new double[]{5, 4}));
        nodeList.add(new Node(new double[]{9, 6}));

        nodeList.add(new Node(new double[]{8, 1}));
        nodeList.add(new Node(new double[]{7, 2}));
        nodeList.add(new Node(new double[]{2, 3}));
        nodeList.add(new Node(new double[]{4, 7}));
        nodeList.add(new Node(new double[]{4, 3}));
        nodeList.add(new Node(new double[]{1, 3}));

        kd_main kdTree = new kd_main();
        //构建二叉树
        Node root = kdTree.buildKDTree(nodeList, 0);
        //打印
        new BinaryTreeOrder().preOrder(root);
        for (Node node : nodeList) {
            String left = "空";
            String right = "空";
            if (node.left != null) {
                left = node.left.toString();
            }
            if (node.right != null) {
                right = node.right.toString();
            }
            System.out.println(node.toString() + "-->" + left + "-->" + right);
        }
        System.out.println(root);
        System.out.println(kdTree.searchKNN(root, new Node(new double[]{2.1, 3.1}), 2));
        System.out.println(kdTree.searchKNN(root, new Node(new double[]{2, 4.5}), 1));
        System.out.println(kdTree.searchKNN(root, new Node(new double[]{2, 4.5}), 3));
        System.out.println(kdTree.searchKNN(root, new Node(new double[]{6, 1}), 5));

    }


    /**
     * 构建kd树  返回根节点
     *
     * @param nodeList
     * @param index
     * @return
     */
    public Node buildKDTree(List<Node> nodeList, int index) {
        if (nodeList == null || nodeList.size() == 0)
            return null;
        quickSortForMedian(nodeList, index, 0, nodeList.size() - 1);//中位数排序
        Node root = nodeList.get(nodeList.size() / 2);//中位数 当做根节点
        root.dim = index;
        List<Node> leftNodeList = new ArrayList<Node>();//放入左侧区域的节点  包括包含与中位数等值的节点-_-
        List<Node> rightNodeList = new ArrayList<Node>();

        for (Node node : nodeList) {
            if (root != node) {
                if (node.getData(index) <= root.getData(index))
                    leftNodeList.add(node);//左子区域 包含与中位数等值的节点
                else
                    rightNodeList.add(node);
            }
        }

        //计算从哪一维度切分
        int newIndex = index + 1;//进入下一个维度
        if (newIndex >= root.data.length)
            newIndex = 0;//从0维度开始再算


        root.left = buildKDTree(leftNodeList, newIndex);//添加左右子区域
        root.right = buildKDTree(rightNodeList, newIndex);

        if (root.left != null)
            root.left.parent = root;//添加父指针
        if (root.right != null)
            root.right.parent = root;//添加父指针
        return root;
    }


    /**
     * 查询最近邻
     *
     * @param root kd树
     * @param q    查询点
     * @param k
     * @return
     */
    public List<Node> searchKNN(Node root, Node q, int k) {
        List<Node> knnList = new ArrayList<Node>();
        searchBrother(knnList, root, q, k);
        return knnList;
    }

    /**
     * searhchBrother
     *
     * @param knnList
     * @param k
     * @param q
     */
    public void searchBrother(List<Node> knnList, Node root, Node q, int k) {
//         Node almostNNode=root;//近似最近点
        Node leafNNode = searchLeaf(root, q);
        double curD = q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
        leafNNode.distance = curD;
        maintainMaxHeap(knnList, leafNNode, k);
        System.out.println("leaf1" + leafNNode.getData(leafNNode.parent.dim));
        while (leafNNode != root) {
            if (getBrother(leafNNode) != null) {
                Node brother = getBrother(leafNNode);
                System.out.println("brother1" + brother.getData(brother.parent.dim));
                if (curD > Math.abs(q.getData(leafNNode.parent.dim) - leafNNode.parent.getData(leafNNode.parent.dim)) || knnList.size() < k) {
                    //这样可能在另一个子区域中存在更加近似的点
                    searchBrother(knnList, brother, q, k);
                }
            }
            System.out.println("leaf2" + leafNNode.getData(leafNNode.parent.dim));
            leafNNode = leafNNode.parent;//返回上一级
            double rootD = q.computeDistance(leafNNode);//最近近似点与查询点的距离 也就是球体的半径
            leafNNode.distance = rootD;
            maintainMaxHeap(knnList, leafNNode, k);
        }
    }


    /**
     * 获取兄弟节点
     *
     * @param node
     * @return
     */
    public Node getBrother(Node node) {
        if (node == node.parent.left)
            return node.parent.right;
        else
            return node.parent.left;
    }

    /**
     * 查询到叶子节点
     *
     * @param root
     * @param q
     * @return
     */
    public Node searchLeaf(Node root, Node q) {
        Node leaf = root, next = null;
        int index = 0;
        while (leaf.left != null || leaf.right != null) {
            if (q.getData(index) < leaf.getData(index)) {
                next = leaf.left;//进入左侧
            } else if (q.getData(index) > leaf.getData(index)) {
                next = leaf.right;
            } else {
                //当取到中位数时  判断左右子区域哪个更加近
                if (q.computeDistance(leaf.left) < q.computeDistance(leaf.right))
                    next = leaf.left;
                else
                    next = leaf.right;
            }
            if (next == null)
                break;//下一个节点是空时  结束了
            else {
                leaf = next;
                if (++index >= root.data.length)
                    index = 0;
            }
        }

        return leaf;
    }

    /**
     * 维护一个k的最大堆
     *
     * @param listNode
     * @param newNode
     * @param k
     */
    public void maintainMaxHeap(List<Node> listNode, Node newNode, int k) {
        if (listNode.size() < k) {
            maxHeapFixUp(listNode, newNode);//不足k个堆   直接向上修复
        } else if (newNode.distance < listNode.get(0).distance) {
            //比堆顶的要小   还需要向下修复 覆盖堆顶
            maxHeapFixDown(listNode, newNode);
        }
    }

    /**
     * 从上往下修复  将会覆盖第一个节点
     *
     * @param listNode
     * @param newNode
     */
    private void maxHeapFixDown(List<Node> listNode, Node newNode) {
        listNode.set(0, newNode);
        int i = 0;
        int j = i * 2 + 1;
        while (j < listNode.size()) {
            if (j + 1 < listNode.size() && listNode.get(j).distance < listNode.get(j + 1).distance)
                j++;//选出子结点中较大的点,第一个条件是要满足右子树不为空

            if (listNode.get(i).distance >= listNode.get(j).distance)
                break;

            Node t = listNode.get(i);
            listNode.set(i, listNode.get(j));
            listNode.set(j, t);

            i = j;
            j = i * 2 + 1;
        }
    }

    private void maxHeapFixUp(List<Node> listNode, Node newNode) {
        listNode.add(newNode);
        int j = listNode.size() - 1;
        int i = (j + 1) / 2 - 1;//i是j的parent节点
        while (i >= 0) {

            if (listNode.get(i).distance >= listNode.get(j).distance)
                break;

            Node t = listNode.get(i);
            listNode.set(i, listNode.get(j));
            listNode.set(j, t);

            j = i;
            i = (j + 1) / 2 - 1;
        }
    }


    /**
     * 使用快排进进行一个中位数的查找  完了之后返回的数组size/2即中位数
     *
     * @param nodeList
     * @param index    某一个维度
     * @param left
     * @param right
     */
    private void quickSortForMedian(List<Node> nodeList, int index, int left, int right) {
        if (left >= right || nodeList.size() <= 0)
            return;

        Node kn = nodeList.get(left);//随便拿出一个节点
        double k = kn.getData(index);//取得向量指定索引的值

        int i = left, j = right;

        //控制每一次遍历的结束条件,i与j相遇  简单说就是比kn小的放左边,比kn大的放右边
        while (i < j) {
            //从右向左找一个小于i处值的值,并填入i的位置
            while (nodeList.get(j).getData(index) >= k && i < j)
                j--;
            nodeList.set(i, nodeList.get(j));
            //从左向右找一个大于i处值的值,并填入j的位置
            while (nodeList.get(i).getData(index) <= k && i < j)
                i++;
            nodeList.set(j, nodeList.get(i));
        }

        nodeList.set(i, kn);


        if (i == nodeList.size() / 2)
            return;//完成中位数的排序了,但并不是完成了所有数的排序,这个终止条件只是保证中位数是正确的。去掉该条件,可以保证在递归的作用下,将所有的树
            //将所有的数进行排序

        else if (i < nodeList.size() / 2) {
            quickSortForMedian(nodeList, index, i + 1, right);//只需要排序右边就可以了
        } else {
            quickSortForMedian(nodeList, index, left, i - 1);//只需要排序左边就可以了
        }

//        for (Node node : nodeList) {
//            System.out.println(node.getData(index));
//        }
    }
}

 

后记

           先这样吧回来再改改

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!