KD-tree讲解

試著忘記壹切 提交于 2019-11-29 09:30:38
  KD-tree  讲解
 by simb351 

  应神犇junble19768的要求,来水一发KD-tree讲解。学习过程中发现关于KD-tree的资源实在是少,
在OI中的应用更是少之又少。虽然这东西很水,随便嘴一嘴就能口胡出来。所以写一篇讲解。

 前置芝士:
1. 基础BST姿势
2. 替罪羊树
3. 优良的空间感 比如能脑补出k维空间QWQ

KD-tree的应用:
  KD-tree主要解决对K维数据的管理,比如多维偏序。但是本弱发现目前OI中KD-tree主流考法为维护
二维平面中区间的信息。比如ION 9102 弹跳 但是好像能被菜鸡simba口胡的暴力算几卡过。


KD-tree的原理:
  考虑从一般BST中类比过来---对于其中一个节点其左边节点的值恒小于它本身,右边反之。实际上是
把所有节点的值从中间分开。如果说BST是对一个一维线段的分割,那么KD-tree就是对K维空间分割。最
终在小的空间内统计答案。说人话就是对K维按顺序均匀分割查哪部分就去哪个块中查找。因为是按序分
割,所以找到答案空间的时间是nlogn至nsqrtn我信你个鬼,不带O2天天被卡。
为了划分空间,KD-tree在第i层维护第i%k维的信息,即这一维中比它小的在左子树,大的在右子树。对
于查询就像BST一样就好了。同BST,考虑KD-tree如何保持自身平衡。由于用方差过于优雅,此处选择替
罪羊树一样的思路---拍扁重建。这样KD-tree就愉快的讲完了,撒花。
KD-tree代码实现:
首先是树的结点。
  
 
struct point
{
    int x[DIM]; //DIM☞维度 x表示一个k维向量 
    bool operator < (const point X) const
    {
        return x[now]<X.x[now];
    }
    /* 考虑分割一个维度时,为了让分割更均匀,要尽量选最中间的点 now表示当前维护维度,定义小于号来维护中间的点。*/
}// 存储一个向量 
struct node
{
    int l,r;//左右子树
    int sze;//子树大小
    int minn[DIM];//此节点维护的空间中第i维的最小值
    int maxx[DIM];//此节点维护的空间中第i维的最大值
    point data;//这个点所维护的向量
}//KD-tree上一个节点的定义

 


  
 
  
然后是维护一个节点的信息。
 
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/

 


 
然后是把一个子树拍扁。
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/

 

接着是将一个序列加到树上,就是把树拍扁后再挂到树上。
 
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/

 

查看这颗树要不要拍扁重建。
void check(int& pos,int dim)
{
    if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
    {rev(pos,0); pos=build(1,tree[pos].sze,dim);}
}// 字面意思 不平衡就拍扁重建 , alpha是替罪羊的平衡因子

 


插入单个点。
 
int insert(int pos,point data,int dim)
{
    if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
    if(data.x[dim]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dim^1);
    else tree[pos].r=insert(tree[pos].r,data,dim^1);
    update(pos); check(pos,dim); return pos;
}//像平衡树一样左右看看加那边,然后挂上节点。最后check一下不让树退化

 

查询,就查询经典问题,给定n个点坐标,以及一个点坐标s 询问n个点中那个离s最近的点是哪个。
 
void query(int pos,point data)
{
    ans=min(ans,dist(data,tree[pos].data)); //用当前点信息更新ans
    int dist_left=INF; int dist_right=INF;
    if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
    if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
    // L,R维护是查询点s到当前点左右子树所维护空间的距离
    if(dist_left<dist_right)
    {
        if(dist_left<ans) query(tree[pos].l,data);
        if(dist_right<ans) query(tree[pos].r,data);
    }
    else 
    {
        if(dist_right<ans) query(tree[pos].r,data);
        if(dist_left<ans) query(tree[pos].l,data);
    }
    // 以当前点为圆心,以ans为半径画圆,如果达不到左/右子树所维护的空间,就不查那边。
}

 


没了,真没了,写写题就好了。
附上bzoj2648代码,就是上面的问题
 
#include<bits/stdc++.h>
#define DEM 2
#define alpha (1.130/2)
#define maxn 1000010
#define INF 0x3f3f3f3f
using namespace std;
int n,m;
int u,v;
int now;
int ans;
int opt;
int root;
int points;
queue<int>Q;
struct point
{
    int x[DEM]; 
    bool operator < (const point X) const {return x[now]<X.x[now];}
}one[maxn];
struct node
{
    int l,r;
    int sze;
    int minn[DEM];
    int maxx[DEM];
    point data;
}tree[maxn];
int dist(point,point);
int get_dist(point,int);
int New();
void update(int);
void rev(int,int);
int build(int,int,int);
void check(int&,int);
int insert(int,point,int);
void query(int,point);
int main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>one[i].x[0]>>one[i].x[1];
    root=build(1,n,0);
    for(int i=1;i<=m;i++)
    {
        cin>>opt>>u>>v;
        if(opt==1) {root=insert(root,(point){u,v},0);}
        else {ans=INF; query(root,(point){u,v}); cout<<ans<<endl;}  
    } 
}
int dist(point A,point B) {return abs(A.x[0]-B.x[0])+abs(A.x[1]-B.x[1]);}
int get_dist(point A,int pos) {int ret=0; for(int i=0;i<DEM;i++) ret+=max(0,A.x[i]-tree[pos].maxx[i])+max(0,tree[pos].minn[i]-A.x[i]); return ret;}
int New()
{
    if(!Q.empty()) {static int tmp; tmp=Q.front(); Q.pop(); return tmp;}
    else return ++points;
}
void update(int pos)
{
    for(int i=0;i<DEM;i++)
    {
        tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
        if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
        if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
        if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
        if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
    }
    tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
}
void rev(int pos,int num)
{
    if(tree[pos].l) rev(tree[pos].l,num);
    one[tree[tree[pos].l].sze+num+1]=tree[pos].data; Q.push(pos);
    if(tree[pos].r) rev(tree[pos].r,tree[tree[pos].l].sze+num+1);
}
int build(int l,int r,int dem)
{
    if(l>r) return 0;
    int mid=(l+r)>>1,pos=New();
    now=dem; nth_element(one+l,one+mid,one+r+1); tree[pos].data=one[mid]; 
    tree[pos].l=build(l,mid-1,dem^1); tree[pos].r=build(mid+1,r,dem^1);
    update(pos); return pos;
}
void check(int& pos,int dem)
{
    if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
    {rev(pos,0); pos=build(1,tree[pos].sze,dem);}
}
int insert(int pos,point data,int dem)
{
    if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
    if(data.x[dem]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dem^1);
    else tree[pos].r=insert(tree[pos].r,data,dem^1);
    update(pos); check(pos,dem); return pos;
}
void query(int pos,point data)
{
    ans=min(ans,dist(data,tree[pos].data));
    int dist_left=INF; int dist_right=INF;
    if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
    if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
    if(dist_left<dist_right)
    {
        if(dist_left<ans) query(tree[pos].l,data);
        if(dist_right<ans) query(tree[pos].r,data);
    }
    else 
    {
        if(dist_right<ans) query(tree[pos].r,data);
        if(dist_left<ans) query(tree[pos].l,data);
    }
}

 


 
 
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!