题意
给定一棵有n个点的树,有三元组$(a,b,c)$满足a,b,c两两距离相等,求这样的三元组的个数。
n<=100000
题解
初步思考
这种题目有个经典解法就是固定两点求中点。但很遗憾这种做法是$O(n^2)$的。
所以我们用树形dp解决这个问题。
考虑从三点lca处统计贡献,这样比较方便。
状态设计
我们考虑逐步满足法,先统计一个点,再用一个点的数据计算两个点都满足的数据,再用一个点和两个点的数据计算三个点都满足的数据。
于是,我们用$f[i][j]$表示在i的子树内与i距离为j的点的个数,
用$g[i][j]$表示在i的子树内有两个点,且他们再加上一个i子树外的与i距离为j的点就满足题目要求。
之所以这样设置是为了方便将g和f拼在一起,且不重不漏地统计答案。
转移方程
那么怎样转移计算g[i][j]的值呢?
因为第二项是j,由定义得现在还需要一条长度为j的边。
考虑一个个枚举i的儿子,统计贡献。
第一种情况,这两条边,一条在已经枚举过的儿子里找,一条在新加进来的找。
那么i就是这两条边的交点,可得两条边长度为i.
下图中,红色表示现在连接的边,黄色表示需要的边,绿色表示将要配对成一个三元组的边。

这种情况可以用$f[i][j]*f[son][j-1]$描述,其中son表示新加进来的儿子。
另外一种情况,两条边都在son里找。
那么就是在son原有的二元组的基础上再加上i--son这一条,因为现在还需要一条长度为j的边,那么在连接之前son则需要一条长度为j+1的边。

这部分可以用$g[son][j+1]$描述
那么,整个g的方程就呼之欲出了
$g[i][j]+=g[son][j+1]+f[i][j]*f[son][j-1]$
f的方程也很简单
$f[id][j]+=f[to][j-1]$
答案的统计也可以列出来了:
$ans+=g[to][j+1]*f[id][j]+g[id][j]*f[to][j-1]$
注意下图的情况要单独计算

可以用$ans+=g[id][0]$来统计。
注意先转移g,再f,再ans。
优化转移
但是,这样仍然是$O(n^2)$
注意到第一个儿子对i的贡献是
$f[x][i]+=f[u][i-1]$
$ g[x][i]+=g[u][i+1]$
也就是说第一个儿子可以$O(1)$转移。
那么我们可以长链剖分,把最深的儿子放在第一位(每个儿子转移复杂度为他的链的长度)
这样,每条链只会在顶部被计算一次,则复杂度为$O(n)$
剩下的就是空间复杂度的问题,我们可以像重链剖分一样把每条链的节点放在一起,因为f和g数组是直接在深儿子的数组上位移而来。所以每条链可以共用一段空间,只不过每个点的起始位不同。对于

存储如下

这样,空间也是$O(n)$的。这样就可以开始敲代码了。
代码
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
#define N 1000001
#define int long long
vector<int> vec[N];
int maxd[N],dson[N],dep[N];
int g[N*5],f[N*5],dfn[N],cnt,top[N],pos[N],cnt2;
void get_maxd(int id,int from)
{
maxd[id]=1;
dep[id]=dep[from]+1;
for(int i=0;i<vec[id].size();i++)
{
int to=vec[id][i];
if(to==from) continue;
get_maxd(to,id);
if(maxd[to]+1>maxd[id]) dson[id]=to;
maxd[id]=max(maxd[to]+1,maxd[id]);
}
}
void get_dfn(int id,int from,int root)
{
dfn[id]=++cnt;
top[id]=root;
if(dson[id]) get_dfn(dson[id],id,root);
for(int i=0;i<vec[id].size();i++)
{
int to=vec[id][i];
if(to==from||to==dson[id]) continue;
pos[to]=cnt2+maxd[to];
cnt2+=maxd[to]*2;
get_dfn(to,id,to);
}
}
int ans;
#define f(i,j) f[dfn[i]+j]
#define g(i,j) g[pos[top[i]]-dep[i]+dep[top[i]]+j]
void solve(int id,int from)
{
f(id,0)=1;
int tot=0;
if(dson[id]) solve(dson[id],id);
for(int i=0;i<vec[id].size();i++)
{
int to=vec[id][i];
if(to==from||to==dson[id]) continue;
solve(to,id);
g(id,0)+=g(to,1);
for(int j=1;j<=maxd[to];j++)
{
tot+=(j<maxd[to])*g(to,j+1)*f(id,j)+g(id,j)*f(to,j-1);
g(id,j)+=(j<maxd[to])*g(to,j+1)+f(id,j)*f(to,j-1);
f(id,j)+=f(to,j-1);
}
}
tot+=g(id,0);
//cout<<id<<" find: "<<tot<<endl;
ans+=tot;
}
signed main()
{
int n;
//freopen("data.txt","r",stdin);
cin>>n;
for(int i=1;i<n;i++)
{
int a,b;
scanf("%lld%lld",&a,&b);
vec[a].push_back(b);
vec[b].push_back(a);
}
get_maxd(1,0);
pos[1]=maxd[1];
cnt2=maxd[1]*2;
get_dfn(1,0,1);
solve(1,0);
cout<<ans;
}
来源:https://www.cnblogs.com/linzhuohang/p/12268975.html