Address
Solution
考虑对 \(k\in [l-1,r]\) 分别求出有多少个集合 \(S\) 满足 \(w(S)\le k\),记作 \(ans_k\)。
先求出 \(1\) 的初始权值 \(W\)。
记 \(val(x)\) 表示 \(x\) 的权值。枚举 \(k\),现在对于每个叶子 \(u\),如果 \(u\in S\),那么 \(val(u)\in [u-W,u+W]\),否则 \(val(u)=W\)。
我们发现,把叶子节点的权值改成 \(W\) 肯定是不优的。所以改动一些叶子后,如果 \(val(1)\) 还是 \(W\),那么肯定路径 \(1→W\) 上每个点的权值都是 \(W\),且其它的点的权值都不是 \(W\)。
因此,如果想要 \(val(1)\) 改变,那么路径 \(1→W\) 上肯定存在一个点 \(x\),\(val(x)\ne W\)。记 \(x\) 在路径 \(1→W\) 上的子节点为 \(y\)。如果 \(x\) 深度是奇数, 那么肯定存在一个 \(x\) 的子节点 \(z(z\ne y)\),\(val(z)>W\)。\(x\) 深度是偶数时同理。
我们把 \(1→W\) 上的边全部断掉,再求一遍每个点的权值。如果原路径 \(1→W\) 上存在某个深度为奇数的点的权值 \(>W\),或者某个深度为偶数的点的权值 \(<W\),那么 \(val(1)\) 肯定改变,否则肯定不变。
记 \(f(u)\) 表示 \(u\) 子树中,使 \(val(u)>w\) 的合法叶子节点集合有几个。\(g(u)\) 表示 \(u\) 子树中,使 \(val(u)<w\) 的合法叶子节点集合有几个。
如果 \(u\) 是叶子节点:\(f(u)=[u>W]+[u+k>W],g(u)=[u<W]+[u-k<W]\)。其中 \([u>W],[u<W]\) 表示 \(u\) 不在叶子节点集合内,\([u+k>W],[u-k<W]\) 表示在集合内。
如果 \(u\) 是深度为奇数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点最大权值必须 \(>W\),也就是说不能全部 \(\le W\)。因此 \(f(u)=2^{cnt_u}\prod_{v\in son_u}(2^{cnt_v}-f(v))\)。其中 \(cnt_u\) 表示 \(u\) 的子树内有几个叶子节点。
如果 \(u\) 是深度为偶数的非叶子节点,如果 \(val(u)>W\),那么 \(u\) 的子节点全部 \(<W\)。因此 \(f(u)=\prod_{v\in son_u}f(v)\)。
\(g\) 的转移和 \(f\) 类似。
接下来求 \(ans_k\)。考虑补集转化,即用 \(2^{cnt_1}\) 减去不会让 \(val(1)\) 改变的集合数。不会让 \(val(1)\) 改变,就是要让原路径 \(1→W\) 上的每个点的权值都不变。那么把深度为奇数的 \(2^{cnt_x}-f_x\) 和深度为偶数的 \(2^{cnt_x}-g_x\) 全部相乘就是答案了。
至此,我们得到了一个 \(O(n(R-L))\) 的做法。
考虑优化,我们发现转移与 \(k\) 无关,只有叶子节点的 \(f,g\) 和 \(k\) 有关。进一步地,我们发现随着 \(k\) 变大,每个叶子节点的 \(f,g\) 最多改变一次。因此可以看作是 \(O(n)\) 次修改的动态 \(dp\),时间复杂度 \(O(n\log^2 n)\)。
Code
#include <bits/stdc++.h> using namespace std; #define ll long long #define p2 p << 1 #define p3 p << 1 | 1 template <class t> inline void read(t & res) { char ch; while (ch = getchar(), !isdigit(ch)); res = ch ^ 48; while (ch = getchar(), isdigit(ch)) res = res * 10 + (ch ^ 48); } template <class t> inline void print(t x) { if (x > 9) print(x / 10); putchar(x % 10 + 48); } const int e = 2e5 + 5, mod = 998244353; struct point { int x, y; }b[e], que[e]; struct matrix { int a, b; matrix(){} matrix(int _a, int _b) : a(_a), b(_b) {} }tr[e << 2]; vector<int>g[e], c[e], d[e]; int f[e], dep[e], L, R, w, n, fa[e], a[e], m, nxt[e], go[e], adj[e], val[e], K, cnt[e], f2[e]; int q[e], h[e], num, all, sum[e << 2], son[e], sze[e], dfnA[e], dfnB[e], timA, timB, idA[e], idB[e]; int st[e], ed[e], bot[e], top[e], ans[e], rt[e], now_rt; bool is[e], op, bo[e]; inline void add(int &x, int y) { (x += y) >= mod && (x -= mod); } inline void del(int &x, int y) { (x -= y) < 0 && (x += mod); } inline int plu(int x, int y) { add(x, y); return x; } inline int sub(int x, int y) { del(x, y); return x; } inline int mul(int x, int y) { return (ll)x * y % mod; } inline int ksm(int x, int y) { int res = 1; while (y) { if (y & 1) res = mul(res, x); y >>= 1; x = mul(x, x); } return res; } inline matrix operator + (matrix u, matrix v) { return matrix(mul(u.a, v.a), plu(mul(u.b, v.a), v.b)); } inline void link1(int x, int y) { g[x].push_back(y); g[y].push_back(x); } inline void link2(int x, int y) { nxt[++num] = adj[x]; adj[x] = num; go[num] = y; } inline void dfs1(int u, int pa) { dep[u] = dep[pa] + 1; fa[u] = pa; if (dep[u] & 1) val[u] = 0; else val[u] = n + 1; int len = g[u].size(), i; bool pd = 0; for (i = 0; i < len; i++) { int v = g[u][i]; if (v == pa) continue; pd = 1; dfs1(v, u); if (dep[u] & 1) val[u] = max(val[u], val[v]); else val[u] = min(val[u], val[v]); } if (!pd) val[u] = u, all++; } inline void dfs2(int u) { if (val[u] == u) { if (op) { f[u] = (u > w) + (u + K > w); if (L <= w + 1 - u && w + 1 - u <= R) c[w + 1 - u].push_back(u); } else { f[u] = (u < w) + (u - K < w); if (L <= u + 1 - w && u + 1 - w <= R) d[u + 1 - w].push_back(u); } return; } f[u] = f2[u] = 1; bool fl = ((dep[u] & 1) && op) || ((~dep[u] & 1) && !op); bo[u] = fl; for (int i = adj[u]; i; i = nxt[i]) { int v = go[i]; dfs2(v); if (fl) f[u] = mul(f[u], sub(q[v], f[v])); else f[u] = mul(f[u], f[v]); if (v != son[u]) { if (fl) f2[u] = mul(f2[u], sub(q[v], f[v])); else f2[u] = mul(f2[u], f[v]); } } if (fl) f[u] = sub(q[u], f[u]); } inline void dfs3(int u) { if (val[u] == u) cnt[u] = 1; sze[u] = 1; rt[u] = now_rt; for (int i = adj[u]; i; i = nxt[i]) { int v = go[i]; dfs3(v); cnt[u] += cnt[v]; sze[u] += sze[v]; if (sze[v] > sze[son[u]]) son[u] = v; } } inline void dfs4(int u, int fi) { top[u] = fi; dfnA[u] = ++timA; idA[timA] = u; if (son[u]) { dfs4(son[u], fi); st[u] = timB + 1; for (int i = adj[u]; i; i = nxt[i]) { int v = go[i]; if (v == son[u]) continue; dfnB[v] = ++timB; idB[timB] = v; } ed[u] = timB; } for (int i = adj[u]; i; i = nxt[i]) { int v = go[i]; if (v == son[u]) continue; dfs4(v, v); } if (son[u]) bot[u] = bot[son[u]]; else bot[u] = u; } inline void build(int l, int r, int p) { if (l == r) { int u = idA[l], v = idB[l]; if (son[u]) { if (bo[u]) { int v = son[u]; tr[p] = matrix(f2[u], sub(q[u], mul(f2[u], q[v]))); } else tr[p] = matrix(f2[u], 0); } if (v) { int pa = fa[v]; if (bo[pa]) sum[p] = sub(q[v], f[v]); else sum[p] = f[v]; } return; } int mid = l + r >> 1; build(l, mid, p2); build(mid + 1, r, p3); tr[p] = tr[p3] + tr[p2]; sum[p] = mul(sum[p2], sum[p3]); } inline void upt_tr(int l, int r, int s, matrix u, int p) { if (l == r) { tr[p] = u; return; } int mid = l + r >> 1; if (s <= mid) upt_tr(l, mid, s, u, p2); else upt_tr(mid + 1, r, s, u, p3); tr[p] = tr[p3] + tr[p2]; } inline void upt_sum(int l, int r, int s, int v, int p) { if (l == r) { sum[p] = v; return; } int mid = l + r >> 1; if (s <= mid) upt_sum(l, mid, s, v, p2); else upt_sum(mid + 1, r, s, v, p3); sum[p] = mul(sum[p2], sum[p3]); } inline matrix ask_tr(int l, int r, int s, int t, int p) { if (l == s && r == t) return tr[p]; int mid = l + r >> 1; if (t <= mid) return ask_tr(l, mid, s, t, p2); else if (s > mid) return ask_tr(mid + 1, r, s, t, p3); else return ask_tr(mid + 1, r, mid + 1, t, p3) + ask_tr(l, mid, s, mid, p2); } inline int ask_sum(int l, int r, int s, int t, int p) { if (l == s && r == t) return sum[p]; int mid = l + r >> 1; if (t <= mid) return ask_sum(l, mid, s, t, p2); else if (s > mid) return ask_sum(mid + 1, r, s, t, p3); else return mul(ask_sum(l, mid, s, mid, p2), ask_sum(mid + 1, r, mid + 1, t, p3)); } inline void pair_mul(point &u, int x) { if (!x) u.y++; else u.x = mul(u.x, x); } inline void pair_div(point &u, int x) { if (!x) u.y--; else u.x = mul(u.x, ksm(x, mod - 2)); } inline void cover(int &x, point u) { int res = u.x; if (u.y) res = 0; x = sub(all, res); } inline int calc(int x, matrix c) { return plu(mul(x, c.a), c.b); } inline int ask(int x) { if (x == bot[x]) return f[x]; int l = dfnA[x], r = dfnA[bot[x]] - 1; return calc(f[bot[x]], ask_tr(1, n, l, r, 1)); } inline void change(int x) { pair_div(que[K], sub(q[rt[x]], f[rt[x]])); x = top[x]; while (x) { f[x] = ask(x); if (!fa[x]) break; int y = fa[x]; if (bo[y]) upt_sum(1, n, dfnB[x], sub(q[x], f[x]), 1); else upt_sum(1, n, dfnB[x], f[x], 1); f2[y] = ask_sum(1, n, st[y], ed[y], 1); matrix tmp; if (bo[y]) { int v = son[y]; tmp = matrix(f2[y], sub(q[y], mul(f2[y], q[v]))); } else tmp = matrix(f2[y], 0); upt_tr(1, n, dfnA[y], tmp, 1); x = top[y]; } pair_mul(que[K], sub(q[x], f[x])); } int main() { freopen("minimax.in", "r", stdin); freopen("minimax.out", "w", stdout); read(n); read(L); read(R); int i, x, y, j; for (i = 1; i < n; i++) read(x), read(y), link1(x, y), b[i].x = x, b[i].y = y; dfs1(1, 0); x = w = val[1]; h[0] = 1; for (i = 1; i <= n; i++) h[i] = plu(h[i - 1], h[i - 1]); while (x != 1) { a[++m] = x; x = fa[x]; } a[++m] = 1; reverse(a + 1, a + m + 1); for (i = 1; i <= m; i++) is[a[i]] = 1; for (i = 1; i < n; i++) { x = b[i].x; y = b[i].y; if (!is[x] || !is[y]) { if (fa[x] == y) link2(y, x); else link2(x, y); } } for (i = 1; i <= m; i++) now_rt = a[i], dfs3(a[i]); all = h[all]; for (i = 1; i <= n; i++) q[i] = h[cnt[i]]; for (i = 1; i <= m; i++) dfs4(a[i], a[i]), fa[a[i]] = 0; bool flag = 0; if (L == 1) K = L, flag = 1, L++; else K = L - 1; que[K].x = 1; for (j = 1; j <= m; j++) { int u = a[j]; op = j & 1; dfs2(u); pair_mul(que[K], sub(q[u], f[u])); } cover(ans[K], que[K]); build(1, n, 1); for (i = L; i <= R; i++) { que[i] = que[i - 1]; K = i; int lenc = c[i].size(), lend = d[i].size(); for (j = 0; j < lenc; j++) { int u = c[i][j]; f[u] = (u > w) + (u + K > w); change(u); } for (j = 0; j < lend; j++) { int u = d[i][j]; f[u] = (u < w) + (u - K < w); change(u); } if (i == n) { ans[i] = sub(all, 1); continue; } cover(ans[i], que[i]); } if (flag) L--; for (i = L; i <= R; i++) print(sub(ans[i], ans[i - 1])), putchar(i == R ? '\n' : ' '); return 0; }
来源:https://www.cnblogs.com/cyf32768/p/12296954.html