1.求<=k点对数,容斥法

/*
求树中距离不超过k的点对数
暴力枚举两点,lca求的复杂度是O(n^2logn),这样很多次询问都是冗余的
那么选择重心作为根,问题分成两部分,求经过重心的距离<=k的点对+不经过重心的距离<=k的点对
先来求第一部分,计算所有点的深度,排序,O(nlogn)可以计算出距离<=k的过重心点对
但是这样还不是正确答案,因为还要容斥掉来自同一棵子树的非法点对,那么对这部分再算一次即可
再求第二部分,这部分其实等价于原问题的子问题,所以我们再去重心的每个子树里找重心,和上面一样求
如果一个点已经被当过重心了,那么给它打个vis标记,之后不再访问
这样最多递归O(logn) 次,所以总复杂度是O(n*logn*logn)
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 10005
using namespace std;
struct Edge{int to,nxt,w;}e[N<<1];
int head[N],tot,n,k,ans;
void add(int u,int v,int w){
e[tot].to=v;e[tot].w=w;e[tot].nxt=head[u];head[u]=tot++;
}
int vis[N],size[N],f[N],root,sum;
void getsize(int u,int pre){
size[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre||vis[v])continue;
getsize(v,u);
size[u]+=size[v];
}
}
void getroot(int u,int pre){
f[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre||vis[v])continue;
getroot(v,u);
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],sum-size[u]);
if(f[u]<f[root])root=u;
}
int o[N],cnt;
void getdeep(int u,int pre,int dep){
o[++cnt]=dep;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre||vis[v])continue;
getdeep(v,u,dep+e[i].w);
}
}
int calc(int u,int dep){
cnt=0;
getdeep(u,u,dep);
sort(o+1,o+cnt+1);
int l=1,r=cnt,res=0;
while(l<r){
if(o[l]+o[r]<=k)res+=r-l,l++;
else r--;
}
return res;
}
void solve(int u){
ans+=calc(u,0);vis[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
ans-=calc(v,e[i].w);
sum=size[v];root=0;
getsize(v,0);getroot(v,0);
solve(root);
}
}
void init(){
tot=ans=0;
memset(head,-1,sizeof head);
memset(size,0,sizeof size);
memset(vis,0,sizeof vis);
}
int main(){
while(cin>>n>>k,n){
init();
for(int i=1;i<n;i++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
f[0]=n;
sum=n;root=0;
getsize(1,0);getroot(1,0);
solve(root);
cout<<ans<<'\n';
}
}
2.求=k点对数,容斥法

/*
给定一棵边权树,每次给定一个询问x:长度为x的路径是否存在
*/
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
#define N 10005
struct Eedge{int to,nxt,w;}e[N<<1];
int head[N],tot,n,m;
void add(int u,int v,int w){
e[tot].to=v;e[tot].w=w;e[tot].nxt=head[u];head[u]=tot++;
}
int ans,x,root,size[N],f[N],sum,vis[N];
void getsize(int u,int pre){
size[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre || vis[v])continue;
getsize(v,u);
size[u]+=size[v];
}
}
void getroot(int u,int pre){
f[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre||vis[v])continue;
getroot(v,u);
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],sum-size[u]);
if(f[root]>f[u])root=u;
}
int o[N],cnt;
void getdeep(int u,int pre,int dep){
o[++cnt]=dep;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==pre || vis[v])continue;
getdeep(v,u,dep+e[i].w);
}
}
int calc(int u,int dep){
cnt=0;
getdeep(u,0,dep);
sort(o+1,o+1+cnt);
int res=0,l=1,r=cnt;
while(l<r){
if(o[l]+o[r]==x){//这里要特别处理一下
if(o[l]==o[r]){
res+=(r-l+1)*(r-l)/2;
break;
}
int p=l,q=r;
while(o[p]==o[l])p++;
while(o[q]==o[r])q--;
res+=(p-l)*(r-q);
l=p;r=q;
}
else if(o[l]+o[r]<x)l++;
else r--;
}
return res;
}
void solve(int u){
ans+=calc(u,0);vis[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(vis[v])continue;
ans-=calc(v,e[i].w);
sum=size[v];root=0;
getsize(v,0);getroot(v,0);
solve(root);
}
}
void init(){
tot=0;
memset(head,-1,sizeof head);
}
int main(){
while(cin>>n,n){
init();
for(int u=1;u<=n;u++){
int v,w;
while(scanf("%d",&v),v){
scanf("%d",&w);
add(u,v,w);add(v,u,w);
}
}
while(scanf("%d",&x),x){
ans=0;
memset(vis,0,sizeof vis);
memset(size,0,sizeof size);
memset(f,0,sizeof f);
f[0]=n;
sum=n;root=0;
getsize(1,0);getroot(1,0);
solve(root);
if(ans)cout<<"AYE\n";
else cout<<"NAY\n";
}
puts(".");
}
}
3.求%1e6+3=k的对数,开桶记录

#include<bits/stdc++.h>
#pragma comment(linker,"/STACK:102400000,102400000")
#include<vector>
using namespace std;
#define N 1000005
#define mod 1000003
#define ll long long
#define INF 0x3f3f3f3f
int inv[N];
void init(){
inv[1]=1;
for(int i=2;i<mod;i++)
inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod,inv[i]=(inv[i]+2*mod)%mod;
}
vector<int>G[N];
int n,k;
ll a[N];
pair<int,int>ans;
int f[N],sum,root,size[N],cnt,vis[N];
struct Node{ll id,val;}o[N];
ll flag[N],tag,id[N];
void update(int a,int b){
if(a>b)swap(a,b);
ans=min(ans,make_pair(a,b));
}
void getsize(int u,int pre){
size[u]=1;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==pre||vis[v])continue;
getsize(v,u);
size[u]+=size[v];
}
}
void getroot(int u,int pre){
f[u]=1;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(vis[v] || v==pre)continue;
getroot(v,u);
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],sum-size[u]);
if(f[root]>f[u])root=u;
}
void getdeep(int u,int pre,ll dep){
o[++cnt].val=dep;o[cnt].id=u;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v==pre||vis[v])continue;
getdeep(v,u,dep*a[v]%mod);
}
}
void solve(int u){
++tag;vis[u]=1;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(vis[v])continue;
cnt=0;
getdeep(v,0,a[v]);
for(int j=1;j<=cnt;j++){
Node cur=o[j];
if(cur.val*a[u]%mod==k)
update(u,cur.id);
ll tmp=1ll*k*inv[cur.val*a[u]%mod]%mod;
if(flag[tmp]==tag)
update(id[tmp],cur.id);
}
for(int j=1;j<=cnt;j++){
Node cur=o[j];
if(flag[cur.val]!=tag||id[cur.val]>cur.id)
flag[cur.val]=tag,id[cur.val]=cur.id;
}
}
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(vis[v])continue;
sum=size[v];root=0;
getsize(v,0);getroot(v,0);
solve(root);
}
}
void clear(){
tag=0;
for(int i=1;i<=n;i++)G[i].clear();
memset(vis,0,sizeof vis);
ans.first=ans.second=INF;
memset(flag,0,sizeof flag);
memset(id,0,sizeof id);
}
int main(){
init();
while(cin>>n>>k){
clear();
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<n;i++){
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
f[0]=n;
sum=n;root=0;
getsize(1,0);getroot(1,0);
solve(root);
if(ans.first==INF)
puts("No solution");
else cout<<ans.first<<" "<<ans.second<<'\n';
}
}
/*
5 1000001
1000002 1 1 1 2
1 2
2 3
3 4
4 5
*/
