参考:http://blog.csdn.net/nixinyis/article/details/65445466
【简介】
点分治是一类用于处理树上路径点权统计问题的算法,其利用重心的性质降低复杂度。
【什么是重心】
某个其所有的子树中最大的子树节点数最少的点被称为重心,删去重心后,生成的多棵树尽可能平衡。
【重心的性质】
①重心其所有子树的大小都不超过$\frac{n}{2}$。
②树中所有点到某个点的距离和中,到树的重心的距离和是最小的,如果有两个重心,那么到它们的距离和相同。
③把两棵树通过两个点相连得到一棵新的树,新的树的重心必定在连接两棵树的重心的路径上。
④一棵树添加或删除一个节点,树的重心最多会移动一条边的位置。
点分治的复杂度基于重心的第一个性质。
【点分治】
点分治是对于每一棵子树,都求出它的重心,并且以重心为根跑一遍这棵子树并统计经过重心的路径,因为我们知道重心所有子树的大小都只有原树的一半,也就是我们这么做最多只会递归$logn$层,若一层的复杂度$O(f(n))$,则总的时间复杂度为$O(f(n)logn)$。
接下来以bzoj1468: Tree为例题讲一下点分治。
对于这道题,显然用上方说的方法,对于每一个子树求出dep,排序后两端指针往中间靠拢统计即可。但是可能会统计重复,如下图,如果红色路径加上父亲边的两倍还小于K,那么也会被根节点统计,所以我们还得删去加上这条父亲边之后的贡献。
【点分治的流程】(以bzoj1468为标准,可能部分内容因题目而异)
点分治由getroot(求重心),calc(计算贡献),getdep(求某个子树以重心为根的所有点的深度),dfs构成。
首先是getroot部分,求重心就不多提了,只要找到一个点满足最大的子树最小即可。
getroot的一些注意事项:getroot之前把root清零,把sum改成子树大小,getroot的时候不访问当前点的父亲或者是dfs已经访问过的点,记得搜到每一个点的时候都要先把mx清零。
getdep部分,calc的时候需要先对这棵子树先getdep,用d数组求出以子树根为根的子树的所有点的深度,并且再用一个dep数组把子树里所有点的深度都塞进去,注意dep只存这棵子树的点的深度,而d是以节点编号为下标的。同样不访问当前点的父亲或者是dfs已经访问过的点。
calc部分,形参有子树根和init,init表示子树根的初始深度,用于上面所说的去重。将子树根初始深度设好,然后getdep,接着就可以处理这个子树的信息了。
最后是dfs部分,每次先对当前点calc(x, 0)计算贡献,然后对每一个son,先把答案去掉calc(son, e[i].dis)的贡献,然后对子树getroot求出重心,接着搜子树的重心即可。
代码如下:

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=40010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, K, x, y, z, sum, ans, tot, root, cnt, tott;
int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn];
bool v[maxn];
inline void read(int &k)
{
int f=1; k=0; char c=getchar();
while(c<'0' || c>'9') c=='-'&&(f=-1), c=getchar();
while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
size[x]=1; mx[x]=0;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too] && too!=fa)
{
getroot(too, x);
size[x]+=size[too];
mx[x]=max(mx[x], size[too]);
}
mx[x]=max(mx[x], sum-size[x]);
if(mx[x]<mx[root]) root=x;
}
void getdep(int x, int fa)
{
dep[++cnt]=d[x];
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too] && too!=fa)
{
d[too]=d[x]+e[i].dis;
getdep(too, x);
}
}
int calc(int x, int init)
{
d[x]=init; cnt=0;
getdep(x, 0);
sort(dep+1, dep+cnt+1);
int l=1, r=cnt, sum=0;
while(l<r)
if(dep[l]+dep[r]<=K) sum+=r-l, l++;
else r--;
return sum;
}
void dfs(int x)
{
ans+=calc(x, 0); v[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
ans-=calc(too, e[i].dis);
sum=size[too]; root=0;
getroot(too, 0);
dfs(root);
}
}
int main()
{
read(n);
for(int i=1;i<n;i++)
read(x), read(y), read(z), add(x, y, z), add(y, x, z);
read(K); sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
printf("%d\n", ans);
}
还有一种写法是对每个重心的子树单独搜,对每个子树统计与搜过子树的方案数,这样子统计得到的路径必定经过重心。
【例题】
例1:bzoj2152: 聪聪可可
这题一眼做法是$O(n)$的,但这里只说点分治的做法。
和上方的例题差不多,只是改成每次求子树里深度%3的个数而已,最后答案就是$sum[0]*sum[0]+sum[1]*sum[2]$。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#define MOD(x) ((x)>=3?(x)-3:(x))
using namespace std;
const int maxn=500010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, ans;
int last[maxn], mx[maxn], size[maxn], cnt[3], d[maxn];
bool v[maxn];
inline void read(int &k)
{
int f=1; k=0; char c=getchar();
while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
inline int gcd(int a, int b){return b?gcd(b, a%b):a;}
void getroot(int x, int fa)
{
mx[x]=0; size[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
getroot(too, x);
size[x]+=size[too];
mx[x]=max(mx[x], size[too]);
}
mx[x]=max(sum-size[x], mx[x]);
if(mx[x]<mx[root]) root=x;
}
void getdep(int x, int fa)
{
cnt[d[x]]++;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
d[too]=MOD(d[x]+e[i].dis);
getdep(too, x);
}
}
int calc(int x, int init)
{
d[x]=init;
cnt[0]=cnt[1]=cnt[2]=0;
getdep(x, 0);
return cnt[0]*cnt[0]+(cnt[1]*cnt[2]<<1);
}
void dfs(int x)
{
ans+=calc(x, 0); v[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
ans-=calc(too, e[i].dis);
sum=size[too]; root=0;
getroot(too, 0);
dfs(root);
}
}
int main()
{
read(n);
for(int i=1;i<n;i++)
read(x), read(y), read(z), z%=3, add(x, y, z), add(y, x, z);
sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
printf("%d/%d\n", ans/gcd(n*n, ans), n*n/gcd(n*n, ans));
}
例2:bzoj2599: [IOI2011]Race
这题对于点分治的初学者来说很容易出错(比如我T T)
因为对于每一个子树我们只能统计经过其重心的路径,而这题是求最小边长,按上面两题的做法做就不可行了。对于每一个子树查询其答案的时候,重心的每一个子树先进行更新答案,再记录其深度的值,这样求得的路径就必定经过重心了,然后再统计重心为端点的答案。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
using namespace std;
const int maxn=1000010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, ans, cnt, K;
int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], ecnt[maxn], mn[maxn];
bool v[maxn];
inline void read(int &k)
{
int f=1; k=0; char c=getchar();
while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
mx[x]=0; size[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
getroot(too, x);
size[x]+=size[too];
mx[x]=max(mx[x], size[too]);
}
mx[x]=max(sum-size[x], mx[x]);
if(mx[x]<mx[root]) root=x;
}
void getans(int x, int fa)
{
dep[++cnt]=d[x];
if(d[x]<=K) ans=min(ans, mn[K-d[x]]+ecnt[x]);
if(d[x]==K) ans=min(ans, ecnt[x]);
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
d[too]=d[x]+e[i].dis;
ecnt[too]=ecnt[x]+1;
getans(too, x);
}
}
void update(int x, int fa)
{
if(d[x]<=K) mn[d[x]]=min(mn[d[x]], ecnt[x]);
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too]) update(too, x);
}
void calc(int x)
{
d[x]=ecnt[x]=cnt=0;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too]) d[too]=e[i].dis, ecnt[too]=1, getans(too, x), update(too, x);
for(int i=1;i<=cnt;i++)
if(K>=dep[i]) mn[dep[i]]=inf;
}
void dfs(int x)
{
calc(x); v[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
sum=size[too]; root=0;
getroot(too, 0);
dfs(root);
}
}
int main()
{
read(n); read(K);
for(int i=1;i<n;i++)
read(x), read(y), read(z), x++, y++, add(x, y, z), add(y, x, z);
memset(mn, 32, sizeof(mn)); ans=inf;
sum=n; mx[0]=inf; getroot(1, 0); dfs(root);
printf("%d\n", ans>n?-1:ans);
return 0;
}
例3:bzoj1316: 树上的询问
这题就是在点分的时候扫一遍所有询问,查询是否有这个len即可,有个坑是len==0的时候答案为Yes。

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
using namespace std;
const int maxn=1000010, inf=1e9, maxl=1000010;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, root, cnt, Q;
int last[maxn], mx[maxn], size[maxn], dep[maxn], d[maxn], len[maxn], ans[maxn];
bool v[maxn], vis[maxn];
inline void read(int &k)
{
int f=1; k=0; char c=getchar();
while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
mx[x]=0; size[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
getroot(too, x);
size[x]+=size[too];
mx[x]=max(mx[x], size[too]);
}
mx[x]=max(sum-size[x], mx[x]);
if(mx[x]<mx[root]) root=x;
}
void getans(int x, int fa)
{
dep[++cnt]=d[x];
for(int i=1;i<=Q;i++)
if(len[i]>=d[x])
if(vis[len[i]-d[x]]) ans[i]=1;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
d[too]=d[x]+e[i].dis, getans(too, x);
}
void update(int x, int fa)
{
if(d[x]<=maxl) vis[d[x]]=1;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too]) update(too, x);
}
void calc(int x)
{
d[x]=cnt=0;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too]) d[too]=e[i].dis, getans(too, x), update(too, x);
for(int i=1;i<=cnt;i++)
if(maxl>=dep[i]) vis[dep[i]]=0;
}
void dfs(int x)
{
calc(x); v[x]=1;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
sum=size[too]; root=0;
getroot(too, 0);
dfs(root);
}
}
int main()
{
read(n); read(Q);
for(int i=1;i<n;i++)
read(x), read(y), read(z), add(x, y, z), add(y, x, z);
for(int i=1;i<=Q;i++) read(len[i]);
sum=n; mx[0]=inf; vis[0]=1; getroot(1, 0); dfs(root);
for(int i=1;i<=Q;i++) printf("%s\n", (ans[i] || !len[i])?"Yes":"No");
return 0;
}
例4:bzoj3697: 采药人的路径
设$f[i][1]$为当前子树长度为i的路径有休息点的方案数,$f[i][0]$为当前子树长度为i路径有休息点的方案数,$g[i][0/1]$表示搜过的子树,其他同理。
统计子树间的答案就是$g[j][0]*f[-j][1]+g[j][1]*f[-j][0]+g[j][1]*f[-j][1]$,统计重心为端点的答案就是$g[0][0]*f[0][0]+f[0][1]$。
调了好久,不熟悉啊T T

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#define ll long long
using namespace std;
const int maxn=200010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn<<1];
int n, x, y, z, tot, sum, mxdep, root;
int size[maxn], mx[maxn], d[maxn], dep[maxn], last[maxn], cnt[maxn];
ll ans, f[maxn][2], g[maxn][2];
bool v[maxn];
inline void read(int &k)
{
int f=1; k=0; char c=getchar();
while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
k*=f;
}
inline void add(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
void getroot(int x, int fa)
{
size[x]=1; mx[x]=0;
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
getroot(too, x);
size[x]+=size[too];
mx[x]=max(mx[x], size[too]);
}
mx[x]=max(mx[x], sum-size[x]);
if(mx[x]<mx[root]) root=x;
}
void update(int x, int fa)
{
if(cnt[d[x]+n]) f[d[x]+n][1]++;
else f[d[x]+n][0]++;
cnt[d[x]+n]++; mxdep=max(mxdep, dep[x]);
for(int i=last[x], too;i;i=e[i].pre)
if((too=e[i].too)!=fa && !v[too])
{
d[too]=d[x]+e[i].dis;
dep[too]=dep[x]+1;
update(too, x);
}
cnt[d[x]+n]--;
}
void calc(int x)
{
int mxd=0;
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
d[too]=e[i].dis; dep[too]=1; mxdep=1; update(too, 0);
mxd=max(mxd, mxdep); ans+=g[n][0]*f[n][0]+f[n][1];
for(int j=-mxdep;j<=mxdep;j++)
ans+=g[j+n][0]*f[n-j][1]+g[j+n][1]*f[n-j][0]+g[j+n][1]*f[n-j][1];
for(int j=-mxdep;j<=mxdep;j++)
g[j+n][1]+=f[j+n][1], g[j+n][0]+=f[j+n][0], f[j+n][1]=f[j+n][0]=0;
}
for(int i=n-mxd;i<=n+mxd;i++) g[i][0]=g[i][1]=0;
}
void dfs(int x)
{
v[x]=1; calc(x);
for(int i=last[x], too;i;i=e[i].pre)
if(!v[too=e[i].too])
{
root=0; sum=size[too];
getroot(too, 0);
dfs(root);
}
}
int main()
{
read(n);
for(int i=1;i<n;i++)
read(x), read(y), read(z), add(x, y, z?1:-1), add(y, x, z?1:-1);
mx[0]=inf; sum=n; getroot(1, 0); dfs(root);
printf("%lld\n", ans);
return 0;
}
来源:https://www.cnblogs.com/Sakits/p/8328707.html
