Splay和一般的BST的区别大概就是:每次插入一个元素后,把它旋转到根
其核心操作就是$splay$和$rotate$
Rotate:把$x$转到$x$的父亲的位置上
$rotate$很好理解,自己画个图玩玩就会了,代码要注意细节
void pushup(int x)
{
sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],k=(x == ch[y][1]);
if(z) ch[z][y == ch[z][1]]=x;
fa[x]=z;
ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y;
ch[x][k^1]=y, fa[y]=x;
pushup(y); pushup(x);
}
Splay:把$x$转到目标$goal$的儿子的位置上(若$goal=0$,则把$x$转到根节点)
分三种情况:
1.$x$与$x$的父亲所属的儿子的种类相同(例如$x$是$y=fa[x]$的左儿子,$fa[x]$是$z=fa[fa[x]]$的左儿子)
2.与上面一种相反
3.$fa[fa[x]]=goal$
对于第一种情况,先把$y$转到$z$,再把$x$转到$y$
对于第二种情况,把$x$一路转上去
对于第三种情况,转一次$x$即可
void splay(int x,int goal)
{
while(fa[x] != goal)
{
int y=fa[x],z=fa[y];
if(z != goal)
(x == ch[y][1]) == (y == ch[z][1]) ? rotate(y) : rotate(x);
rotate(x);
}
if(!goal) rt=x;
}
除了以上操作,$Splay$还有一些附带操作,大都很好理解,但是$Delete$和$Find_kth$要注意一下
$Findkth$:找到当前第$k$大的数
细节很多,要想清楚
int Find_kth(int x)
{
int cur=rt;
while(sz[ch[cur][0]] >= x || sz[ch[cur][0]]+cnt[cur] < x)
if(sz[ch[cur][0]] >= x) cur=ch[cur][0];
else x-=sz[ch[cur][0]]+cnt[cur], cur=ch[cur][1];
return cur;
}
$Delete$:删除某个数$x$(若有多个只删除一个)
先把$x$的前驱转到根节点,在把它的后继转到根节点的儿子上
这样要删的数$x$就只可能独立地挂在它的后继的左儿子上
然后直接删就好了
void Delete(int x)
{
int lst=Pre(x),nxt=Suf(x);
splay(lst,0); splay(nxt,lst);
int tar=ch[nxt][0];
if(cnt[tar]>1) --cnt[tar], splay(tar,0);
else ch[nxt][0]=0, splay(nxt,0); //这里splay是为了更新size
}
注意:为了防止边界上出现一些奇奇怪怪的错误,开局先在平衡树内插入$inf$和$-inf$
#include<bits/stdc++.h>
using namespace std;
const int N=110000,inf=2e9+7;
int rt,node,fa[N],ch[N][2],cnt[N],sz[N],val[N];
void pushup(int x)
{
sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
}
void rotate(int x)
{
int y=fa[x],z=fa[y],k=(x == ch[y][1]);
if(z) ch[z][y == ch[z][1]]=x;
fa[x]=z;
ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y;
ch[x][k^1]=y, fa[y]=x;
pushup(y); pushup(x);
}
void splay(int x,int goal)
{
while(fa[x] != goal)
{
int y=fa[x],z=fa[y];
if(z != goal)
(x == ch[y][1]) == (y == ch[z][1]) ? rotate(y) : rotate(x);
rotate(x);
}
if(!goal) rt=x;
}
void Insert(int x)
{
if(!rt)
{
rt=++node,val[node]=x,++cnt[node],++sz[node];
return;
}
int cur=rt;
while(233)
{
++sz[cur];
if(val[cur] == x) {++cnt[cur]; return;}
int z=ch[cur][val[cur]<x];
if(!z)
{
ch[cur][val[cur]<x]=z=++node,fa[z]=cur;
val[z]=x,++cnt[z],++sz[z];
splay(z,0);
return ;
}
cur=z;
}
}
int Find(int x)
{
int cur=rt;
while(val[cur] != x)
cur=ch[cur][val[cur]<x];
return cur;
}
int Pre(int x)
{
int cur=rt,ret;
while(cur)
if(val[cur]<x) ret=cur, cur=ch[cur][1];
else cur=ch[cur][0];
return ret;
}
int Suf(int x)
{
int cur=rt,ret;
while(cur)
{
if(val[cur]>x) ret=cur, cur=ch[cur][0];
else cur=ch[cur][1];
}
return ret;
}
void Delete(int x)
{
int lst=Pre(x),nxt=Suf(x);
splay(lst,0); splay(nxt,lst);
int tar=ch[nxt][0];
if(cnt[tar]>1) --cnt[tar], splay(tar,0);
else ch[nxt][0]=0, splay(nxt,0);
}
int Find_kth(int x)
{
int cur=rt;
while(sz[ch[cur][0]] >= x || sz[ch[cur][0]]+cnt[cur] < x)
if(sz[ch[cur][0]] >= x) cur=ch[cur][0];
else x-=sz[ch[cur][0]]+cnt[cur], cur=ch[cur][1];
return cur;
}
int main()
{
//freopen(" .in","r",stdin); freopen(" .out","w",stdout);
int n,opt,x;
Insert(inf); Insert(-inf);
scanf("%d",&n);
while(n--)
{
scanf("%d%d",&opt,&x);
if(opt == 1) Insert(x);
if(opt == 2) Delete(x);
if(opt == 3) {x=Find(x); splay(x,0); printf("%d\n",sz[ch[x][0]]);}
if(opt == 4) printf("%d\n",val[Find_kth(++x)]);
if(opt == 5) printf("%d\n",val[Pre(x)]);
if(opt == 6) printf("%d\n",val[Suf(x)]);
}
}
来源:https://www.cnblogs.com/w19567/p/12191274.html