算法:min-max 容斥、树上背包、NTT。
题意简述
有一棵 \(n\) 个点的树。一开始所有点都是白色,每次操作会随机选择 \(\frac{n \times (n + 1)}{2}\) 条路径中的一条,将路径上所有点染黑。求所有点都被染黑的期望操作数。
\(n \le 50\)。多组数据。对 \(998, 244, 353\) 取模。
题解
套路性地,我们使用 min-max 容斥。
\[
\begin{aligned}E(\max(U)) &= \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} E(\min(S)) \\ &= \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} \frac{1}{P(路径经过\ S\ 内的点)} \\ &= \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} \frac{1}{1 - P(路径不经过\ S\ 内的点)}\end{aligned}
\]
如果我们把树画出来,并将 \(S\) 内的点标记,可以发现,这些点把原树分成了若干个连通块,而每个连通块内部可以任意选取路径,能够保证该路径不经过 \(S\) 内的点。而一旦路径跨越连通块,那么一定经过 \(S\) 内的点。
记共分成了 \(x\) 个连通块 \(T_1, T_2, \cdots, T_x\),则不经过 \(S\) 的路径数为
\[
\sum_{i = 1} ^ {x} \frac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}
\]
那么概率为
\[
\frac{\sum_{i = 1} ^ {x} \lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{n \times (n + 1)}
\]
继续化简原式,得
\[
\begin{aligned}E(\max(U)) &= \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} \frac{1}{1 - \dfrac{\sum_{i = 1} ^ {x} \lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{n \times (n + 1)}} \\ &= \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} \frac{n \times (n + 1)}{n \times (n + 1) - \sum_{i = 1} ^ {x} \lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)} \\ &= \frac{n \times (n + 1)}{2} \sum_{\phi \neq S \subset U} (-1) ^ {\lvert S \rvert - 1} \frac{1}{\dfrac{n \times (n + 1)}{2} - \sum_{i = 1} ^ {x} \dfrac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}}\end{aligned}
\]
对于最终情况,一个划分方案给答案带来的贡献仅仅是 \((-1) ^ {|S| - 1}\) 和 \(\sum_{i = 1} ^ x \dfrac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}\) 两块。为了获得这些信息,我们只关心于选点个数为奇数、偶数的方案中,有多少个方案的 \(\sum_{i = 1} ^ x \dfrac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}\) 为我枚举的常数 \(k\)。
考虑设计一个动态规划来解决这个问题。我们先选取 \(1\) 号点作为全树的根,把该树变成一棵有根树。
记 \(f[i][odd][j][k]\) 表示只考虑了以 \(i\) 为根的子树,选了奇数/偶数个点,包含 \(i\) 的连通块大小为 \(j\),\(\sum_{i = 1} ^ x \dfrac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}\)(不含 \(i\) 所在连通块)的值为 \(k\) 的方案数。其中 \(j = 0\) 表示 \(i\) 号点不在连通块中(即在所选的点集中)。
此外,我们记一个数组 \(g[i][odd][k]\) 表示以 \(i\) 为根的子树,选了奇数/偶数个点,\(\sum_{i = 1} ^ x \dfrac{\lvert T_i \rvert \times \left( \lvert T_i \rvert + 1 \right)}{2}\)(含 \(i\) 所在连通块)的值为 \(k\) 的方案数。显然有
\[
g[i][odd][k'] = \sum_{j = 0} ^ n \sum_{k = 0}^{\frac{n \times (n + 1)}{2}} \left[ \frac{j \times (j + 1)}{2} + k == k' \right] f[i][odd][j][k]
\]
转移类似树上背包。合并 \(i\) 与一个儿子 \(v\) 时,有转移式:
\[
f[i][odd][j][k] = \left\{\begin{aligned} &\sum_{odd' \in \{0, 1\}} \sum_{k' = 0} ^ {k} f[i][odd][j][k'] \cdot g[v][odd \oplus odd'][k - k'] \quad & j = 0 \\ &\sum_{odd' \in \{0, 1\}} \sum_{j' = 1} ^ {j} \sum_{k' = 0}^{k} f[i][odd][j][k] \cdot f[v][odd \oplus odd'][j - j'][k - k'] \quad & j \neq 0\end{aligned}\right.
\]
暴力转移复杂度是 \(\mathcal{O}(n^7)\) 的,并不可接受。
我们留意到 \(j\) 一维的大小是不超过 \(i\) 的子树大小的,这是树上背包的经典形式。因此复杂度就少了一个 \(n\),变成了 \(\mathcal{O}(n^6)\)。
继续观察,发现 \(k\) 一维的转移是一个卷积形式,可以使用 NTT 进行优化。换句话说,除了把 \(f[i]\) 向 \(g[i]\) 算贡献的地方,如果我们把状态的最后一维看成一个 \(\frac{n \times (n + 1)}{2}\) 次的多项式,那么这里的所有运算都可以看成多项式加法和多项式乘法。因此我们可以在一开始就用点值表示法表示 \(f\) 和 \(g\) 的值,只有在 \(f\) 向 \(g\) 算贡献的时候,我们才进行一次 INTT,把点值表示法还原回系数表示法,得到 \(g\) 后,就又可以还原成点值表示法。
于是原本 \(\mathcal{O}(n^2)\) 的转移被优化到了 \(\mathcal{O}(\log n)\)。这样以后总复杂度变为 \(\mathcal{O}(n ^ 4 \log n)\)。
测试的时候我的程序每组数据要跑 \(0.5 \text{ s}\),这题的数据组数 \(T = 15\) 以后我就被卡了常数。目前只能过前 \(9\) 个点,最后 \(3\) 个点仍在努力卡常中。
#include <algorithm> #include <cstdio> #include <cstring> const int MaxN = 50 + 5; const int MaxV = 4096 + 5; const int Mod = 998244353, Prt = 3; struct Graph { int cnte; int Head[MaxN], To[MaxN * 2], Next[MaxN * 2]; inline void clear() { cnte = 0; memset(Head, 0, sizeof Head); memset(To, 0, sizeof To); memset(Next, 0, sizeof Next); } inline void addEdge(int from, int to) { cnte++; To[cnte] = to; Next[cnte] = Head[from]; Head[from] = cnte; } }; int Te, N; int Fa[MaxN], Siz[MaxN]; int F[MaxN][2][MaxN][MaxV], G[MaxN][2][MaxV]; int Rev[MaxV][MaxV], W[2][MaxV], Inv[MaxV]; Graph Gr; inline int add(int x, int y) { return (x += y) >= Mod ? x - Mod : x; } inline int sub(int x, int y) { return (x -= y) < 0 ? x + Mod : x; } inline int mul(int x, int y) { return 1LL * x * y % Mod; } inline int pw(int x, int y) { int z = 1; for (; y; y >>= 1, x = mul(x, x)) if (y & 1) z = mul(z, x); return z; } inline int inv(int x) { return pw(x, Mod - 2); } inline int sep(int x, int y) { return mul(x, inv(y)); } inline void inc(int &x, int y = 1) { x = add(x, y); } inline void dec(int &x, int y = 1) { x = sub(x, y); } void init() { scanf("%d", &N); for (int i = 1; i < N; ++i) { int u, v; scanf("%d %d", &u, &v); Gr.addEdge(u, v); Gr.addEdge(v, u); } } void dfs1(int u) { Siz[u] = 1; for (int i = Gr.Head[u]; i; i = Gr.Next[i]) { int v = Gr.To[i]; if (v == Fa[u]) continue; Fa[v] = u; dfs1(v); Siz[u] += Siz[v]; } } inline void ntt(int *a, int n, int f) { for (int i = 1; i < n; ++i) if (i < Rev[n][i]) std::swap(a[i], a[Rev[n][i]]); for (int i = 1; i < n; i <<= 1) { int w = W[f][i]; for (int j = 0; j < n; j += (i << 1)) { int x = 1; for (int k = 0; k < i; ++k, x = mul(x, w)) { int lson = a[j + k], rson = a[i + j + k]; a[j + k] = add(lson, mul(rson, x)); a[i + j + k] = sub(lson, mul(rson, x)); } } } if (f == 1) for (int i = 0; i < n; ++i) a[i] = mul(a[i], Inv[n]); } inline int getPow2(int n) { int v = 1; while (v < n) v <<= 1; return v; } inline void calcG(int u) { for (int odd = 0; odd <= 1; ++odd) { for (int j = 0; j <= N; ++j) for (int k = 0; k <= N * (N + 1) / 2; ++k) { int newK = j * (j + 1) / 2 + k; if (newK > N * (N + 1) / 2) break; inc(G[u][odd][newK], F[u][odd][j][k]); } } } void dfs2(int u) { F[u][1][0][0] = F[u][0][1][0] = 1; int sz = 1; for (int i = Gr.Head[u]; i; i = Gr.Next[i]) { int v = Gr.To[i]; if (v == Fa[u]) continue; dfs2(v); sz += Siz[v]; static int f[2][MaxN][MaxV]; int len = getPow2(Siz[u] * (Siz[u] + 1)); for (int odd = 0; odd <= 1; ++odd) for (int j = 0; j <= sz; ++j) for (int k = 0; k < len; ++k) f[odd][j][k] = 0; for (int odd = 0; odd <= 1; ++odd) { ntt(G[v][odd], len, 0); for (int j = 0; j <= sz - Siz[v]; ++j) ntt(F[u][odd][j], len, 0); for (int j = 0; j <= Siz[v]; ++j) ntt(F[v][odd][j], len, 0); } for (int odd = 0; odd <= 1; ++odd) { for (int j = 0; j <= sz; ++j) { for (int odd2 = 0; odd2 <= 1; ++odd2) { if (j == 0) { for (int k = 0; k < len; ++k) inc(f[odd][j][k], mul(F[u][odd ^ odd2][j][k], G[v][odd2][k])); } else { for (int j2 = std::max(0, j - sz + Siz[v]); j2 <= Siz[v] && j2 < j; ++j2) { for (int k = 0; k < len; ++k) inc(f[odd][j][k], mul(F[u][odd ^ odd2][j - j2][k], F[v][odd2][j2][k])); } } } } } for (int odd = 0; odd <= 1; ++odd) { ntt(G[v][odd], len, 1); for (int j = 0; j <= sz - Siz[v]; ++j) ntt(F[u][odd][j], len, 1); for (int j = 0; j <= Siz[v]; ++j) ntt(F[v][odd][j], len, 1); for (int j = 0; j <= sz; ++j) ntt(f[odd][j], len, 1); } for (int odd = 0; odd <= 1; ++odd) for (int j = 0; j <= sz; ++j) for (int k = 0; k <= sz * (sz + 1) / 2; ++k) F[u][odd][j][k] = f[odd][j][k]; } calcG(u); } void solve() { dfs1(1); dfs2(1); int ans = 0; for (int odd = 0; odd <= 1; ++odd) for (int k = 0; k < N * (N + 1) / 2; ++k) { if (odd == 1) inc(ans, mul(G[1][odd][k], inv(N * (N + 1) / 2 - k))); else dec(ans, mul(G[1][odd][k], inv(N * (N + 1) / 2 - k))); } ans = mul(ans, N * (N + 1) / 2); printf("%d\n", ans); } void clear() { memset(Fa, 0, sizeof Fa); memset(Siz, 0, sizeof Siz); memset(F, 0, sizeof F); memset(G, 0, sizeof G); Gr.clear(); } int main() { for (int i = 1; i <= 4096; i <<= 1) { Rev[i][0] = 0; for (int j = 1; j < i; ++j) { Rev[i][j] = (Rev[i][j >> 1]) >> 1; if (j & 1) Rev[i][j] |= (i >> 1); } Inv[i] = inv(i); } for (int f = 0; f <= 1; ++f) for (int i = 1; i < 4096; i <<= 1) W[f][i] = pw(pw(Prt, f == 0 ? 1 : Mod - 2), (Mod - 1) / (i << 1)); scanf("%d", &Te); for (int t = 1; t <= Te; ++t) { init(); printf("Case #%d: ", t); solve(); clear(); } return 0; }