题目描述
题解
吼题但题解怎么这么迷
考虑一种和题解不同的做法(理解)
先把僵尸离散化,h相同的钦(ying)点一个大小
(可以发现这样每种情况只会被算正好一次)
计算完全被占领的方案,然后1-方案/概率
由于大小确定了,所以最后会被分成若干不相连的块,且块中有且仅有一只僵尸
定义一个块的编号为所占领僵尸的编号
设f[i][x](x>0)表示以i为根的子树中点i所在块的编号为x
那么对于f[j][y](j∈son[i],j>0)转移如下:
①x=y
f[j][y](僵尸x经过i--j的方案数)-->f[i][x]
那么x和y在同一个块中,因为一个块只有一只僵尸,所以块内必须要连通
②x<y
f[j][y](僵尸y不经过i--j的方案数)-->f[i][x]
x和y不在同一个块中,所以x和y不能连通,即较大的僵尸(y)不能走到另一个点(i)
并且要保证j中存在y,不存在x,原因见下文
③x>y
f[j][y]*(僵尸x不经过i--j的方案数)-->f[i][x]
原因&范围同上
初值为f[i][x]=[x>=h[i]](x>0)
对于②③的限制:
因为要保证以某个点i为最浅点的块内刚好存在僵尸x,
在i与fa[i]断开时保证了x在i的子树中,i所在块的叶子与块中叶子的儿子断开保证了x不在块外,所以块中必定存在x
时间复杂度O(n^3),前后缀优化成O(n^2)
code
#include <algorithm> #include <iostream> #include <cstdlib> #include <cstring> #include <bitset> #include <cstdio> #define fo(a,b,c) for (a=b; a<=c; a++) #define fd(a,b,c) for (a=b; a>=c; a--) #define add(a,b) a=((a)+(b))%998244353 #define min(a,b) (a<b?a:b) #define max(a,b) (a>b?a:b) #define mod 998244353 #define Mod 998244351 using namespace std; struct type{ int x,id; } b[2001]; int a[4002][2]; int c[2001][2001]; int C[2001]; int ls[2001]; int L[2001]; int R[2001]; int h[2001]; int H[2001]; long long f[2001][2001]; long long s1[2002]; long long s2[2002]; bitset<2001> bz[2001]; int T,N,n,m,i,j,k,l,len; long long ans,s; bool cmp(type a,type b) { return a.x<b.x; } void New(int x,int y) { ++len; a[len][0]=y; a[len][1]=ls[x]; ls[x]=len; } long long qpower(long long a,int b) { long long ans=1; while (b) { if (b&1) ans=ans*a%mod; a=a*a%mod; b>>=1; } return ans; } void Dfs(int Fa,int t) { int i; if (h[t]) bz[t][h[t]]=1; for (i=ls[t]; i; i=a[i][1]) if (a[i][0]!=Fa) { Dfs(t,a[i][0]); bz[t]|=bz[a[i][0]]; } } void dfs(int Fa,int t) { int i,j,k,l,id; long long x; fo(i,max(1,h[t]),N) f[t][i]=1; for (i=ls[t]; i; i=a[i][1]) if (a[i][0]!=Fa) { id=i/2; dfs(t,a[i][0]); fo(k,1,N) { s1[k]=s1[k-1]; if (bz[a[i][0]][k]) add(s1[k],f[a[i][0]][k]); } s2[N+1]=0; fd(k,N,1) { s2[k]=s2[k+1]; if (bz[a[i][0]][k]) add(s2[k],f[a[i][0]][k]*max(R[id]-max(H[k],L[id])+1,0)%mod); } fo(j,1,N) { if (!bz[a[i][0]][j]) f[t][j]=f[t][j]*(s2[j+1]+s1[j-1]*max(R[id]-max(H[j],L[id])+1,0)%mod+f[a[i][0]][j]*max(min(H[j]-1,R[id])-L[id]+1,0)%mod)%mod; else f[t][j]=f[t][j]*(f[a[i][0]][j]*max(min(H[j]-1,R[id])-L[id]+1,0)%mod)%mod; // O(n^3) // fo(k,1,N) // if (f[a[i][0]][k]) // { // if (j<k) // x=max(R[id]-max(H[k],L[id])+1,0); // if (j==k) // x=max(min(H[k]-1,R[id])-L[id]+1,0); // if (j>k) // x=max(R[id]-max(H[j],L[id])+1,0); // // if (j==k || bz[a[i][0]][k] && !bz[a[i][0]][j]) // add(F[j],f[t][j]*f[a[i][0]][k]%mod*x); // } } } } int main() { freopen("zombie.in","r",stdin); freopen("zombie.out","w",stdout); scanf("%d",&T); for (;T;--T) { memset(bz,0,sizeof(bz)); memset(ls,0,sizeof(ls)); memset(h,0,sizeof(h)); memset(H,0,sizeof(H)); memset(f,0,sizeof(f)); memset(C,0,sizeof(C)); len=1; scanf("%d%d",&n,&m); fo(i,1,n-1) { scanf("%d%d%d%d",&j,&k,&L[i],&R[i]); New(j,k); New(k,j); } fo(i,1,m) { scanf("%d%d",&j,&k); h[j]=max(h[j],k); } N=0; fo(i,1,n) if (h[i]) b[++N]={h[i],i}; sort(b+1,b+N+1,cmp); fo(i,1,N) { H[i]=b[i].x; h[b[i].id]=i; } Dfs(0,1); dfs(0,1); ans=0; fo(i,1,N) add(ans,f[1][i]); s=1; fo(i,1,n-1) s=s*(R[i]-L[i]+1)%mod; ans=ans*qpower(s,Mod)%mod; printf("%lld\n",((1-ans)%mod+mod)%mod); } fclose(stdin); fclose(stdout); return 0; }