#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; const int maxn = 1e6; int col[maxn]; int num[maxn]; int vis[maxn]; int idx,h[maxn],e[maxn],ne[maxn]; int tl[maxn],tr[maxn]; struct query{ int l,r,id,k; }a[maxn]; void add(int a,int b){ e[idx]=b; ne[idx]=h[a]; h[a]=idx++; } void dfs(int u,int fa){ //子树对应的dfs序的左区间 tl[u]=++idx; num[idx]=col[u]; for(int i=h[u];~i;i=ne[i]){ if(e[i]!=fa) dfs(e[i],u); } //子树对应的dfs序的右区间 tr[u]=idx; } int sz; bool cmp(query a,query b){ if(a.l/sz!=b.l/sz) return a.l<b.l; return a.r<b.r; } int lowbit(int x){ return x&-x; } int tree[maxn]; int query(int pos){ int ans=0; for(int i=pos;i;i-=lowbit(i)) ans+=tree[i]; return ans; } void update(int pos,int val){ if(pos<=0) return; for(int i=pos;i<maxn;i+=lowbit(i)) tree[i]+=val; } int ans[maxn]; int main() { memset(h,-1,sizeof h); int n,m; scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",&col[i]); for(int i=1;i<n;i++){ int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } idx=0; dfs(1,-1); sz=sqrt(idx); for(int i=1;i<=m;i++){ int x,y; scanf("%d%d",&x,&y); a[i].l=tl[x]; a[i].r=tr[x]; a[i].k=y; a[i].id=i; } sort(a+1,a+1+m,cmp); int L=1,R=0; for(int i=1;i<=m;i++){ while(L<a[i].l){ update(vis[num[L]],-1); vis[num[L]]--; update(vis[num[L]],1); L++; } while(R<a[i].r){ R++; update(vis[num[R]],-1); vis[num[R]]++; update(vis[num[R]],1); } while(L>a[i].l){ L--; update(vis[num[L]],-1); vis[num[L]]++; update(vis[num[L]],1); } while(R>a[i].r){ update(vis[num[R]],-1); vis[num[R]]--; update(vis[num[R]],1); R--; } ans[a[i].id]=query(maxn-1)-query(a[i].k-1); } for(int i=1;i<=m;i++) printf("%d\n",ans[i]); return 0; }
来源:https://www.cnblogs.com/QingyuYYYYY/p/12380934.html