太久没写博客了,过来水一发。
题目链接:洛谷
首先我们想到,考虑每个叶节点的权值为根节点权值的概率。首先要将叶节点权值离散化。
假设现在是$x$节点,令$f_i,g_i$分别表示左/右节点的权值$=i$的概率。
若$w_x$来自于左儿子,则
$$P(w_x=i)=f_i*(p_x*\sum_{j=1}^{i-1}g_j+(1-p)*\sum_{j=i+1}^mg_j)$$
右儿子也是一样的。
所以在转移的时候需要顺便维护$f,g$的前/后缀和。
但是我们发现这样直接跑是$O(n^2)$的,肯定不行,但是每个节点的所有dp值都只依赖于两个儿子,而且区间乘法是可以使用lazy_tag的,所以可以使用线段树合并。
(等会儿,好像之前并没有写过。。。)
线段树合并就是对于值域线段树,合并的时候如果两棵树都有这个节点,那么就递归下去,否则直接按照上面的式子转移。
$f,g$的前/后缀和也可以放在参数里面顺便维护了。

1 #include<bits/stdc++.h>
2 #define Rint register int
3 using namespace std;
4 typedef long long LL;
5 const int N = 300003, mod = 998244353, inv = 796898467;
6 int n, v[N], tot, p[N], fa[N], head[N], to[N], nxt[N];
7 inline void add(int a, int b){
8 static int cnt = 0;
9 to[++ cnt] = b; nxt[cnt] = head[a]; head[a] = cnt;
10 }
11 int root[N], ls[N << 5], rs[N << 5], seg[N << 5], tag[N << 5], cnt, ans;
12 inline void pushdown(int x){
13 if(x && tag[x] != 1){
14 if(ls[x]){
15 seg[ls[x]] = (LL) seg[ls[x]] * tag[x] % mod;
16 tag[ls[x]] = (LL) tag[ls[x]] * tag[x] % mod;
17 }
18 if(rs[x]){
19 seg[rs[x]] = (LL) seg[rs[x]] * tag[x] % mod;
20 tag[rs[x]] = (LL) tag[rs[x]] * tag[x] % mod;
21 }
22 tag[x] = 1;
23 }
24 }
25 inline void change(int &x, int L, int R, int pos){
26 if(!x) tag[x = ++ cnt] = 1;
27 pushdown(x);
28 ++ seg[x];
29 if(seg[x] >= mod) seg[x] = 0;
30 if(L == R) return;
31 int mid = L + R >> 1;
32 if(pos <= mid) change(ls[x], L, mid, pos);
33 else change(rs[x], mid + 1, R, pos);
34 }
35 inline int merge(int lx, int rx, int L, int R, int pl, int pr, int sl, int sr, int P){
36 if(!lx && !rx) return 0;
37 int now = ++ cnt, mid = L + R >> 1; tag[now] = 1;
38 pushdown(lx); pushdown(rx);
39 if(!lx){
40 int v = ((LL) P * sl + (mod + 1ll - P) * sr) % mod;
41 seg[now] = (LL) seg[rx] * v % mod;
42 tag[now] = (LL) tag[rx] * v % mod;
43 ls[now] = ls[rx]; rs[now] = rs[rx];
44 return now;
45 }
46 if(!rx){
47 int v = ((LL) P * pl + (mod + 1ll - P) * pr) % mod;
48 seg[now] = (LL) seg[lx] * v % mod;
49 tag[now] = (LL) tag[lx] * v % mod;
50 ls[now] = ls[lx]; rs[now] = rs[lx];
51 return now;
52 }
53 ls[now] = merge(ls[lx], ls[rx], L, mid, pl, (pr + seg[rs[rx]]) % mod, sl, (sr + seg[rs[lx]]) % mod, P);
54 rs[now] = merge(rs[lx], rs[rx], mid + 1, R, (pl + seg[ls[rx]]) % mod, pr, (sl + seg[ls[lx]]) % mod, sr, P);
55 seg[now] = (seg[ls[now]] + seg[rs[now]]) % mod;
56 return now;
57 }
58 inline void getans(int x, int L, int R){
59 pushdown(x);
60 if(L == R){
61 ans = (ans + (LL) seg[x] * seg[x] % mod * v[L] % mod * L % mod) % mod;
62 return;
63 }
64 int mid = L + R >> 1;
65 getans(ls[x], L, mid);
66 getans(rs[x], mid + 1, R);
67 }
68 inline void dfs(int x){
69 if(!head[x]){
70 change(root[x], 1, n, p[x]);
71 return;
72 }
73 for(Rint i = head[x];i;i = nxt[i]){
74 dfs(to[i]);
75 if(!root[x]) root[x] = root[to[i]];
76 else root[x] = merge(root[x], root[to[i]], 1, n, 0, 0, 0, 0, p[x]);
77 }
78 }
79 int main(){
80 scanf("%d", &n);
81 for(Rint i = 1;i <= n;i ++){
82 scanf("%d", fa + i);
83 if(fa[i]) add(fa[i], i);
84 }
85 for(Rint i = 1;i <= n;i ++){
86 scanf("%d", p + i);
87 if(head[i]) p[i] = (LL) p[i] * inv % mod;
88 else v[++ tot] = p[i];
89 }
90 sort(v + 1, v + tot + 1);
91 for(Rint i = 1;i <= n;i ++)
92 if(!head[i]) p[i] = lower_bound(v + 1, v + tot + 1, p[i]) - v;
93 dfs(1);
94 getans(root[1], 1, n);
95 printf("%d", ans);
96 }
来源:https://www.cnblogs.com/AThousandMoons/p/10893829.html
