题面
解析
求到$S$集合每个点走一次的期望,即求$E(max(S))$,套上$Min-Max$容斥,即是求$E(min(T)),T\subseteq S$
考虑对每种集合做一次$dp$,外层枚举$P \subseteq U$,$dp[u]$表示点$u$到$P$集合内任意一点的期望时间, $deg[u]$表示点$u$的度数,$fa$为$u$的父亲节点,$v$为$u$的儿子节点。
若$u \in P$,则$dp[u] = 0$
否则有:$dp[u] = \frac{1}{deg[u]}(dp[fa]+\sum dp[v])+1$
然后是一个我没总结过的较常见套路,设$dp[u] = A[u] * dp[fa] + B[u]$,带入上式化简:$$deg[u]*dp[u]=dp[fa]+\sum(A[v]*dp[u]+B[v])+deg[u]$$$$(deg[u]-\sum A[v])*dp[u]=dp[fa]+(\sum B[v])+deg[u]$$$$dp[u]=\frac{1}{deg[u]-\sum A[v]}*dp[fa]+\frac{deg[u]+\sum B[v]}{deg[u]-\sum A[v]}$$
故:$$A[u]=\frac{1}{deg[u]-\sum A[v]},\ B[u]=\frac{deg[u]+\sum B[v]}{deg[u]-\sum A[v]}$$
对于根节点,由于其没有父节点,故$dp[u]=B[u]$,也即$B[u]$为所求。
现在可以求出$E(min(T))$,但我们需要求出$\sum_{T\subseteq S}(-1)^{|T|+1}*E(min(T))$,可以发现其实就是求$S$的子集权值和, 可以用$FWT(or)$预处理出所有$S$的答案,每次询问可以做到$O(1)$回答。
因$DP$过程中需要求逆元,故时间复杂度为$DP$的时间复杂度:$O(n2^n \log mod)$
代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn = (1 << 18) + 5, mod = 998244353;
ll qpow(ll x, ll y)
{
ll ret = 1;
while(y)
{
if(y&1)
ret = ret * x % mod;
x = x * x % mod;
y >>= 1;
}
return ret;
}
ll add(ll x, ll y)
{
return x + y < mod? x + y: x + y - mod;
}
ll rdc(ll x, ll y)
{
return x - y < 0? x - y + mod: x - y;
}
int n, m, Q, rt, deg[20], num[maxn];
ll f[maxn], A[20], B[20];
vector<int> G[maxn];
void dfs(int x, int fa, int s)
{
if((s >> (x - 1)) & 1)
{
A[x] = B[x] = 0;
return ;
}
ll s1 = 0, s2 = 0;
for(auto &id: G[x])
{
if(id == fa) continue;
dfs(id, x, s);
s1 = add(s1, A[id]);
s2 = add(s2, B[id]);
}
A[x] = qpow(rdc(deg[x], s1), mod - 2);
B[x] = add(s2, deg[x]) * A[x] % mod;
}
void FWT(ll *x)
{
for(int i = 1; i <= m; i <<= 1)
for(int j = 0; j <= m; j += (i << 1))
for(int k = 0; k < i; ++k)
x[i+j+k] = add(x[i+j+k], x[j+k]);
}
int main()
{
scanf("%d%d%d", &n, &Q, &rt);
int u, v, cnt, sta;
for(int i = 1; i < n; ++i)
{
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
++ deg[u];
++ deg[v];
}
m = (1 << n) - 1;
for(int i = 1; i <= m; ++i)
{
dfs(rt, 0, i);
num[i] = num[i>>1] + (i & 1);
f[i] = ((num[i] & 1)? 1: mod - 1) * B[rt] % mod;
}
FWT(f);
for(int i = 1; i <= Q; ++i)
{
scanf("%d", &cnt);
sta = 0;
for(int j = 1; j <= cnt; ++j)
{
scanf("%d", &u);
sta |= (1 << (u - 1));
}
printf("%lld\n", f[sta]);
}
return 0;
}
来源:https://www.cnblogs.com/Joker-Yza/p/12397263.html
