树链剖分模板
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define LL long long
using namespace std;
const int MAXN = 2e5+10;//点的个数
struct node{
int l,r;
int sum,laze;//线段树
}tree[MAXN<<2];
struct edge{
int next,to;
}e[MAXN<<1];
int a[MAXN];
int head[MAXN];
int siz[MAXN];//子树的大小
int top[MAXN];//重链的顶端
int son[MAXN];//每个节点的重儿子
int d[MAXN];//每个节点的深度
int fa[MAXN];//每个节点的父亲节点
int id[MAXN];//每个节点的DFS序
int rk[MAXN];//每个DFS序对应的节点
inline int L(int x){return x<<1;};
inline int R(int x){return x<<1|1;};
inline int MID(int l,int r){return (l+r)>>1;};
int n,m,r,MOD,uu,vv;
int cnt=0;
void add(int x,int y){
e[++cnt].next=head[x];
e[cnt].to=y;
head[x]=cnt;
}
void dfs1(int u,int f,int depth){
d[u]=depth;
fa[u]=f;
siz[u]=1; //这个点本身size=1
for (int i=head[u];i;i=e[i].next){
int v=e[i].to;
if (v==f)
continue;
dfs1(v,u,depth+1); //层次深度+1
siz[u]+=siz[v]; //子节点的size已经被处理,用它来更新父亲节点
if (siz[v]>siz[son[u]])
son[u]=v; //选取size最大的作为重儿子并不断更新
}
}
void dfs2(int u,int t){
top[u]=t; //标记这个节点,重链顶端
id[u]=++cnt; //标记DFS序列
rk[cnt]=a[u];
if (!son[u])//如果到根节点
return;
dfs2(son[u],t);
//我们选择有限进入重儿子,让重儿子的DFS序连续
for (int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if (v!=son[u] && v!=fa[u])//如果一个点不是重儿子,并且这个节点也不是其父亲节点
dfs2(v,v);
}
}
void push_down(int root){
if (tree[root].laze){
tree[L(root)].laze+=tree[root].laze;
tree[R(root)].laze+=tree[root].laze;
tree[L(root)].sum+=(tree[L(root)].r-tree[L(root)].l+1)*tree[root].laze;
tree[R(root)].sum+=(tree[R(root)].r-tree[R(root)].l+1)*tree[root].laze;
tree[L(root)].sum%=MOD;
tree[R(root)].sum%=MOD;
tree[root].laze=0;
}
}
void buildtree(int root,int l,int r){
tree[root].l=l;
tree[root].r=r;
if (l==r){
tree[root].sum=rk[l]%MOD;
return ;
}
int mid=MID(l,r);
buildtree(L(root),l,mid);
buildtree(R(root),mid+1,r);
tree[root].sum=(tree[L(root)].sum+tree[R(root)].sum)%MOD;
}
int query(int root,int ql,int qr){
int l=tree[root].l;
int r=tree[root].r;
int res=0;
if (ql<=l && r<=qr){
return tree[root].sum;
}
push_down(root);
int mid=MID(l,r);
if (qr<=mid){
res=query(L(root),ql,qr);
}else if (ql>mid){
res=query(R(root),ql,qr);
}else {
res=query(L(root),ql,mid);
res+=query(R(root),mid+1,qr);
}
return res%MOD;
}
void update(int root,int ul,int ur,int w){
int l=tree[root].l;
int r=tree[root].r;
if (ul<=l && r<=ur){
tree[root].laze+=w;
tree[root].sum+=(r-l+1)*w;
return ;
}
push_down(root);
int mid=MID(l,r);
if (ur<=mid){
update(L(root),ul,ur,w);
}else if (ul>mid){
update(R(root),ul,ur,w);
}else{
update(L(root),ul,mid,w);
update(R(root),mid+1,ur,w);
}
tree[root].sum=(tree[L(root)].sum+tree[R(root)].sum)%MOD;
}
int qRange(int x,int y){
int ans=0;
while(top[x]!=top[y]){//不在一条链上
if (d[top[x]]<d[top[y]])swap(x,y);//把x变成深的节点
ans+=query(1,id[top[x]],id[x]);//求和
ans%=MOD;
x=fa[top[x]];//在跳到链的顶端的上面一个点
}//直到两个点处于一条链上
if (d[x]>d[y])swap(x,y);//在同一层后继续
ans+=query(1,id[x],id[y]);
return ans%MOD;
}
void updRange(int x,int y,int k){
k%=MOD;
while(top[x]!=top[y]){
if (d[top[x]]<d[top[y]])swap(x,y);
update(1,id[top[x]],id[x],k);
x=fa[top[x]];
}
if (d[x]>d[y])swap(x,y);
update(1,id[x],id[y],k);
}
int qson(int x){
return query(1,id[x],id[x]+siz[x]-1)%MOD;
}
void updson(int x,int k){
// cout<<id[x]<<" "<<siz[x]<<endl;
update(1,id[x],id[x]+siz[x]-1,k);
}
int main(){
while(~scanf("%d%d%d%d",&n,&m,&r,&MOD)){
memset(head,0,sizeof(head));
memset(id,0,sizeof(id));
for (int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for (int i=1;i<n;i++){
scanf("%d%d",&uu,&vv);
add(uu,vv);
add(vv,uu);
}
cnt=0;
dfs1(r,0,1);
dfs2(r,r);
buildtree(1,1,n);
while(m--){
int op,x,y,z;
scanf("%d",&op);
if (op==1){
scanf("%d%d%d",&x,&y,&z);
updRange(x,y,z);
}else if (op==2){
scanf("%d%d",&x,&y);
printf("%d\n",qRange(x,y));
}else if (op==3){
scanf("%d%d",&x,&y);
// cout<<x<<" "<<y<<endl;
updson(x,y);
}else {
scanf("%d",&x);
printf("%d\n",qson(x));
}
}
}
return 0;
}