线段树合并
应用范围:将子树的信息合并给父亲节点,并且权值线段树的下标值域和节点数相近。
CF600E Lomsat gelral
题意:一棵树有n个结点,每个结点都是一种颜色,每个颜色有一个编号,求树中每个子树的最多的颜色编号的和。
\(1 <= n<=1e5\)
解法:线段树合并,这个东西的时空复杂度都很玄学,姑且认为时间为\(O(nlogn)\),空间为(常数\(\times log_n\times n\)),常数一般为\(4-8\)。
#include <cstdio> #include <algorithm> using namespace std; #define maxn 100100 #define ll long long int n; int fir[maxn], nxt[maxn * 2], vv[maxn * 2]; int tot = 0; void add(int u, int v) { nxt[++tot] = fir[u]; fir[u] = tot; vv[tot] = v; } int cnt = 0; int root[maxn], col[maxn]; int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2]; ll ans[maxn * 17 * 2]; void pushup(int a) { if(sum[lz[a]] < sum[rz[a]]) { sum[a] = sum[rz[a]]; ans[a] = ans[rz[a]]; } if(sum[lz[a]] > sum[rz[a]]) { sum[a] = sum[lz[a]]; ans[a] = ans[lz[a]]; } if(sum[lz[a]] == sum[rz[a]]) { sum[a] = sum[lz[a]]; ans[a] = ans[lz[a]] + ans[rz[a]]; } return; } int merge(int a, int b, int l, int r) { if(a == 0) return b; if(b == 0) return a; if(l == r) { sum[a] += sum[b]; ans[a] = l; return a; } int mid = (l + r) >> 1; lz[a] = merge(lz[a], lz[b], l, mid); rz[a] = merge(rz[a], rz[b], mid + 1, r); pushup(a); return a; } void update(int &a, int l, int r, int v) { if(!a) a = ++cnt; int mid = (l + r) >> 1; if(l == r) { sum[a] += 1; ans[a] = l; return; } if(mid >= v) update(lz[a], l, mid, v); if(mid < v) update(rz[a], mid + 1, r, v); pushup(a); return; } void dfs(int u, int fa) { for(int i = fir[u]; i; i = nxt[i]) { int v = vv[i]; if(v == fa) continue; dfs(v, u); merge(root[u], root[v], 1, 100000); } update(root[u], 1, 100000, col[u]); ans[u] = ans[root[u]]; } int main() { scanf("%d", &n); cnt = n; for(int i = 1; i <= n; i++) { scanf("%d", &col[i]); root[i] = i; } for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0); for(int i = 1; i <= n; i++) printf("%lld ", ans[i]); return 0; }
雨天的尾巴
注意\(ans\)要在\(dfs\)时计算,不然当前节点的\(root\)可能被父亲节点继承,然后就炸了。
#include <cstdio> #include <algorithm> using namespace std; #define maxn 100100 #define ll long long int n; int fir[maxn], nxt[maxn * 2], vv[maxn * 2]; int tot = 0; void add(int u, int v) { nxt[++tot] = fir[u]; fir[u] = tot; vv[tot] = v; } int cnt = 0; int root[maxn], col[maxn]; int lz[maxn * 17 * 2], rz[maxn * 17 * 2], sum[maxn * 17 * 2]; ll ans[maxn * 17 * 2]; void pushup(int a) { if(sum[lz[a]] < sum[rz[a]]) { sum[a] = sum[rz[a]]; ans[a] = ans[rz[a]]; } if(sum[lz[a]] > sum[rz[a]]) { sum[a] = sum[lz[a]]; ans[a] = ans[lz[a]]; } if(sum[lz[a]] == sum[rz[a]]) { sum[a] = sum[lz[a]]; ans[a] = ans[lz[a]] + ans[rz[a]]; } return; } int merge(int a, int b, int l, int r) { if(a == 0) return b; if(b == 0) return a; if(l == r) { sum[a] += sum[b]; ans[a] = l; return a; } int mid = (l + r) >> 1; lz[a] = merge(lz[a], lz[b], l, mid); rz[a] = merge(rz[a], rz[b], mid + 1, r); pushup(a); return a; } void update(int &a, int l, int r, int v) { if(!a) a = ++cnt; int mid = (l + r) >> 1; if(l == r) { sum[a] += 1; ans[a] = l; return; } if(mid >= v) update(lz[a], l, mid, v); if(mid < v) update(rz[a], mid + 1, r, v); pushup(a); return; } void dfs(int u, int fa) { for(int i = fir[u]; i; i = nxt[i]) { int v = vv[i]; if(v == fa) continue; dfs(v, u); merge(root[u], root[v], 1, 100000); } update(root[u], 1, 100000, col[u]); ans[u] = ans[root[u]]; } int main() { scanf("%d", &n); cnt = n; for(int i = 1; i <= n; i++) { scanf("%d", &col[i]); root[i] = i; } for(int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0); for(int i = 1; i <= n; i++) printf("%lld ", ans[i]); return 0; }