树链剖分,感觉是一个很神奇的东西,但是其实并不是那样的
树链剖分其实就是一个线段树
线段树处理的是连续区间,所以当你要加的时候都是连续区间修改
所以可以用轻重链的方式将树分解成为链条,然后用线段树处理
可以很容易看到,为什么用的是dfs但不是用的是bfs呢
因为dfs保持了重链是连续的,所以可以用top[x]记录已x为节点的重链最上方,一个点也包含在重链内
若修改区间为(u,v),但是重链的祖先是一起的,所以当他们的LCA相同时,边break
所以现在u,v是连续的
所以查询(u,v)的简单路径和也就处理了
所以说线段树中可以进行的操作在树上也可以执行了
在处理一个问题
在u的子树上加w
所以修改的区间是u在线段树中的位置$(t)$ 到 $t+size(u)-1$
$size$ 记录以它为根 的子节点个数
$deep(x)$ 深度
$father(x)$ 记录父亲
$son(x)$ 它的重儿子
$top(x)$ 所在重路径的顶部节点
$seg(x)$ x在线段树中的编号
$rev(x)$ 线段树中x的位置所对应的树中节点编号

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
inline int read()
{
int f=1,ans=0;char c;
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
return ans*f;
}
int n,val[1200001];
struct node{
int u,v,nex;
}x[1200001];
int head[1200001],cnt;
int size[1200001];
int deep[1200001];
int father[1200001];
int son[1200001];
int top[1200001];
int seg[1200001];
int rev[1200001];
int q,root,mod;
void dfs1(int f,int fath)
{
deep[f]=deep[fath]+1;
father[f]=fath;
size[f]=1;
for(int i=head[f];i!=-1;i=x[i].nex)
{
if(x[i].v==fath) continue;
dfs1(x[i].v,f);
size[f]+=size[x[i].v];
if(size[x[i].v]>size[son[f]]) son[f]=x[i].v;
}
return;
}
void dfs2(int f,int fath)
{
if(son[f])
{
top[son[f]]=top[f];
seg[son[f]]=++seg[0];
rev[seg[0]]=son[f];
dfs2(son[f],f);
}
for(int i=head[f];i!=-1;i=x[i].nex)
{
if(x[i].v==fath) continue;
if(top[x[i].v]) continue;
top[x[i].v]=x[i].v;
seg[x[i].v]=++seg[0];
rev[seg[0]]=x[i].v;
dfs2(x[i].v,f);
}
return;
}
void add(int u,int v)
{
x[cnt].u=u,x[cnt].v=v,x[cnt].nex=head[u],head[u]=cnt++;
}
int ans[1200001],sum[1200001];
void build(int k,int l,int r)
{
if(l==r)
{
ans[k]=val[rev[l]];
return;
}
int mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
ans[k]=ans[k<<1]+ans[k<<1|1];
return;
}
void push_down(int k,int l,int r)
{
int mid=l+r>>1;
ans[k<<1]+=sum[k]*(mid-l+1);sum[k<<1]%=mod;
sum[k<<1]+=sum[k];sum[k<<1]%=mod;
ans[k<<1|1]+=sum[k]*(r-mid);ans[k<<1|1]%=mod;
sum[k<<1|1]+=sum[k];sum[k<<1|1]%=mod;
sum[k]=0;
return;
}
void add(int k,int l,int r,int x,int y,int v)
{
if(x<=l&&r<=y){
sum[k]+=v;
sum[k]%=mod;
ans[k]+=((r-l+1)%mod)*v%mod;
ans[k]%=mod;
return;
}
push_down(k,l,r);
int mid=l+r>>1;
if(x<=mid) add(k<<1,l,mid,x,y,v);
if(mid<y) add(k<<1|1,mid+1,r,x,y,v);
ans[k]=ans[k<<1]+ans[k<<1|1];
ans[k]%=mod;
}
void ask_add(int x,int y,int w)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
add(1,1,seg[0],seg[fx],seg[x],w%mod);
x=father[fx],fx=top[x];
}
if(deep[x]>deep[y]) swap(x,y);
add(1,1,seg[0],seg[x],seg[y],w);
}
int summ;
int query(int k,int l,int r,int x,int y)
{
if(x<=l&&r<=y) return ans[k]%mod;
push_down(k,l,r);
int res=0,mid=l+r>>1;
if(x<=mid) res+=query(k<<1,l,mid,x,y)%mod;
if(mid<y) res+=query(k<<1|1,mid+1,r,x,y)%mod;
return res;
}
int ask(int x,int y)
{
summ=0;
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(deep[fx]<deep[fy]) swap(x,y),swap(fx,fy);
summ+=query(1,1,seg[0],seg[fx],seg[x])%mod;
x=father[fx],fx=top[x];
}
if(deep[x]>deep[y]) swap(x,y);
summ+=query(1,1,seg[0],seg[x],seg[y])%mod;
return summ%mod;
}
int main()
{
memset(head,-1,sizeof(head));
n=read(),q=read(),root=read(),mod=read();
for(int i=1;i<=n;i++) val[i]=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs1(root,0);
seg[0]=1;seg[root]=1;
top[root]=root;
rev[1]=root;
dfs2(root,0);
build(1,1,seg[0]);
while(q--)
{
int s=read();
if(s==1)
{
int u=read(),v=read();
int w=read();
ask_add(u,v,w%mod);
}
if(s==3)
{
summ=0;
int u=read(),v=read();
add(1,1,seg[0],seg[u],seg[u]+size[u]-1,v%mod);
}
if(s==2)
{
int u=read(),v=read();
printf("%lld\n",ask(u,v)%mod);
}
if(s==4)
{
int u=read();
printf("%lld\n",query(1,1,seg[0],seg[u],seg[u]+size[u]-1)%mod);
}
}
return 0;
}
来源:https://www.cnblogs.com/si-rui-yang/p/9703477.html
