树链剖分主要解决在树上某条路径上或某棵子树的sum与最值
入门树链剖分,最重要的概念是重儿子,用son[]记录。son[i]代表的是以i为根,节点最多子树根的编号。通过son,我们将树中的边分为两种,轻边和重边。重边是每一个点与它重儿子的连边。将链续的重边串起来,构成了一条条重链。对于每个点,它一定在某条重链上,特殊的,一个单独的点也可以是一条重链。对于一条重链上的点,他们的dfs序是连续的,对与一颗子树上的点,他们的dfs序也是连续的,于是我们将树上的点转化成一个区间,在区间用线段树上求解或修改。
树链剖分的核心便是如何将树剖分成若干链。
在此之前,先了解以下七个参数的意义。
fa[x] x的父亲节点编号
dep[x] x的深度
size[x] 以x为根子树的节点,用来求son
seg[x] 以son为基础的dfs2序,即其在线段树上的编号
rev[x] 用来将线段树上的编号转化为原编号。
son[x] 记录重儿子
top[x] 记录x所在重链dep最小的点
树链剖分需要的前置知识有DFS+LCA+线段树
我们在两次dfs中求出上述七个参数
dfs1:求出dep,f,size,son。
inline void dfs1(int u,int f){
int i,v;
size[u]=1;
fa[u]=f;
dep[u]=dep[f]+1;
for(i=fir[u];v=to[i],i;i=nex[i]){
if(v!=f){
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]])//更新重儿子
son[u]=v;
}
}
}
dfs2:求出rev,seg,top。
inline void dfs2(int u,int f){
int i,v;
if(son[u]){//优先遍历重儿子。
seg[son[u]]=++tim;
rev[tim]=son[u];
top[son[u]]=top[u];//重儿子的top,就是u的top。
dfs2(son[u],u);
}
for(i=fir[u];v=to[i],i;i=nex[i])
if(!top[v]){//访问轻边
seg[v]=++tim;
rev[tim]=v;
top[v]=v;//轻边单独开了一条链,top是本身
dfs2(v,u);
}
}
两遍dfs就把整棵树划分为若干条链,剩下的就交给线段树解决了。
首先是建树
inline void build(int k,int l,int r){
if(l==r){
sum[k]=ma[k]=w[l];//w是每个点的全值
return;
}
int mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
ma[k]=max(ma[k<<1],ma[k<<1|1]);
sum[k]=sum[k<<1]+sum[k<<1|1];
}
线段树的查询和修改类似。
inline void change(int k,int l,int r,int val,int pos){//pos是当前要修改的的位置,val是改变后的值
if(l>pos||r<pos)
return;
if(l==r&&l==pos){
ma[k]=sum[k]=val;
return;
}
int mid=l+r>>1;
change(k<<1,l,mid,val,pos);
change(k<<1|1,mid+1,r,val,pos);
sum[k]=sum[k<<1]+sum[k<<1|1];
ma[k]=max(ma[k<<1],ma[k<<1|1]);
}
inline void query(int k,int x,int y,int l,int r){//l~r为需要修改的区间
if(x>r||y<l)
return;
if(x>=l&&y<=r){
SUM+=sum[k];
MAX=max(ma[k],MAX);
return;
}
int mid=x+y>>1;
query(k<<1,x,mid,l,r);
query(k<<1|1,mid+1,y,l,r);
}
最后,我们只需要知道只需要知道哪些点对这条路径有贡献,统计他们的贡献即可。
inline void ask(int x,int y){
inline void ask(int x,int y){ int fx=top[x],fy=top[y];
while(fx!=fy){//如果他们不在同一重链上
if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);//选取深度大的那一条,
query(1,1,tim,seg[fx],seg[x]);//注意要将原编号转化为dfs序编号
x=fa[x],fx=top[x];
} //如果他们在一条链上了,再统计x~y路径的贡献
if(dep[x]>dep[y]) swap(x,y);//保证x的编号小等于y
query(1,1,tim,seg[x],seg[y]);
}
下面附上一道模板题
https://www.lydsy.com/JudgeOnline/problem.php?id=1036


#include<cstdio>
#include<iostream>
#include<cstring>
#define max(x,y) (x>y?x:y)
#define N 100000
using namespace std;
int n,m,tot,tim,SUM,MAX;
int fir[N],to[N],nex[N];
int seg[N],rev[N],size[N],son[N],dep[N],top[N],fa[N];
int sum[N],ma[N],w[N];
inline void r(int &x){
bool sign=1;
x=0;
char ch=getchar();
while(ch<'0'||ch>'9') ch=getchar();
if(ch=='-') sign=0,ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
x=sign?x:-x;
}
inline void add(int x,int y){
to[++tot]=y,nex[tot]=fir[x],fir[x]=tot;
to[++tot]=x,nex[tot]=fir[y],fir[y]=tot;
}
inline void dfs1(int u,int f){
int i,v;
size[u]=1;
fa[u]=f;
dep[u]=dep[f]+1;
for(i=fir[u];v=to[i],i;i=nex[i]){
if(v!=f){
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]])
son[u]=v;
}
}
}
inline void dfs2(int u,int f){
int i,v;
if(son[u]){
seg[son[u]]=++tim;
rev[tim]=son[u];
top[son[u]]=top[u];
dfs2(son[u],u);
}
for(i=fir[u];v=to[i],i;i=nex[i])
if(!top[v]){
seg[v]=++tim;
rev[tim]=v;
top[v]=v;
dfs2(v,u);
}
}
inline void build(int k,int l,int r){
if(l==r){
sum[k]=ma[k]=w[l];
return;
}
int mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
ma[k]=max(ma[k<<1],ma[k<<1|1]);
sum[k]=sum[k<<1]+sum[k<<1|1];
}
inline void change(int k,int l,int r,int val,int pos){
if(l>pos||r<pos)
return;
if(l==r&&l==pos){
ma[k]=sum[k]=val;
return;
}
int mid=l+r>>1;
change(k<<1,l,mid,val,pos);
change(k<<1|1,mid+1,r,val,pos);
sum[k]=sum[k<<1]+sum[k<<1|1];
ma[k]=max(ma[k<<1],ma[k<<1|1]);
}
inline void query(int k,int x,int y,int l,int r){
if(x>r||y<l)
return;
if(x>=l&&y<=r){
SUM+=sum[k];
MAX=max(ma[k],MAX);
return;
}
int mid=x+y>>1;
query(k<<1,x,mid,l,r);
query(k<<1|1,mid+1,y,l,r);
}
inline void ask(int x,int y){
int fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);
query(1,1,tim,seg[fx],seg[x]);
x=fa[x],fx=top[x];
}
if(dep[x]>dep[y]) swap(x,y);
query(1,1,tim,seg[x],seg[y]);
}
int main()
{
int i,j,x,y;
char op[10];
r(n);
for(i=1;i<n;i++){
r(x),r(y);
add(x,y);
}
for(i=1;i<=n;i++)
r(w[i]);
tim=seg[1]=top[1]=rev[1]=1;
dfs1(1,0);
dfs2(1,0);
build(1,1,tim);
r(m);
for(i=1;i<=m;i++){
scanf("%s",op);
r(x),r(y);
SUM=0;
MAX=-N;
switch(op[1]){
case 'M':{
ask(x,y);
printf("%d\n",MAX);
break;
}
case 'S':{
ask(x,y);
printf("%d\n",SUM);
break;
}
case 'H':{
change(1,1,tim,y,seg[x]);
break;
}
}
}
}
2019-09-04
