题目描述
题解
随便bb
详细题解见
https://www.cnblogs.com/coldchair/p/11624979.html
https://blog.csdn.net/alan_cty/article/details/84557477
https://www.cnblogs.com/Iking123/p/11626041.html
这里讲讲自己发现的东西和一些细节
f[i][p][a]表示第i位以后(包括第i位)的最大值,a表示个位,在第i为进1的个位会变成什么
为什么要包括第i位呢,因为假设的进位不会影响到p,所以考虑上包括的情况都一样
而且可以适应第i位不为0的情况,更严谨一些
g[i][p][x][a]的x表示第i位放x后个位会变成什么,只需要进x次位就行了
当i=1时要特殊考虑(考虑能否放x)
至于f和g的取值是否重复,显然不会不然怎么做
把方程列出来后可以发现结果其实是存在原来的状态上的,所以不会重复不然就是列错了
dp[i][j][p][a]表示dfs序为i,做到第j位的方案数
转移前缀和优化,可以直接把dp[i]设为原来的dp[1..i]
要考虑i=1和j=1
code
#include <algorithm> #include <iostream> #include <cstdlib> #include <cstring> #include <cstdio> #define fo(a,b,c) for (a=b; a<=c; a++) #define fd(a,b,c) for (a=b; a>=c; a--) #define add(a,b) (a=(a+(b)%998244353)%998244353) #define min(a,b) (a<b?a:b) #define max(a,b) (a>b?a:b) #define mod 998244353 using namespace std; struct type{ int x,y; } E[1001]; int a[1001][2]; int ls[501]; int f[501][10][10]; //f[i][p][a] p=max in [>=i] a=the first int g[501][10][10][10]; //g[i][p][x][a] int dp[501][501][10][10]; //dp[i][j][p][a] int st[501]; int d[501]; int fa[501]; int N,A,n,K,I,i,j,k,l,len; long long ans; bool cmp(type a,type b) { return a.x<b.x || a.x==b.x && a.y>b.y; } void New(int x,int y) { ++len; a[len][0]=y; a[len][1]=ls[x]; ls[x]=len; } void dfs(int Fa,int t) { int i; fa[t]=Fa; st[t]=++N; for (i=ls[t]; i; i=a[i][1]) if (a[i][0]!=Fa) dfs(t,a[i][0]); } void Dfs(int Fa,int t) { int i,j,k,l; i=st[t]; fo(j,1,n) { fo(k,0,K-1) { fo(l,0,K-1) dp[i][j][k][l]=dp[i-1][j][k][l]; } } fo(j,1,n) { fo(k,0,K-1) { fo(l,0,K-1) { if (j>1) { if (t>1) add(dp[i][j-1][max(k,d[t])][g[j][k][d[t]][l]],dp[i-1][j][k][l]-dp[st[fa[t]]-1][j][k][l]); else add(dp[i][j-1][max(k,d[t])][g[j][k][d[t]][l]],dp[i-1][j][k][l]); } else if (g[j][k][d[t]][l]>-1) { if (t>1) add(ans,dp[i-1][j][k][l]-dp[st[fa[t]]-1][j][k][l]); else add(ans,dp[i-1][j][k][l]); } } } } for (i=ls[t]; i; i=a[i][1]) if (a[i][0]!=Fa) Dfs(t,a[i][0]); } int main() { freopen("buried.in","r",stdin); freopen("buried.out","w",stdout); scanf("%d%d",&n,&K); fo(i,1,n) scanf("%d",&d[i]); fo(i,2,n) { scanf("%d%d",&j,&k); E[++l]={j,k}; E[++l]={k,j}; } sort(E+1,E+l+1,cmp); fo(i,1,l) New(E[i].x,E[i].y); memset(f,255,sizeof(f)); memset(g,255,sizeof(g)); fo(i,0,K-1) { fo(j,0,K-1) if (i|j) { k=j; while (k<K) k+=max(i,k); f[2][i][j]=k%K; } } fo(i,2,n-1) { fo(j,0,K-1) { fo(k,0,K-1) if (j|k) { l=k; fo(I,0,K-1) if (f[i][max(j,I)][l]>-1) l=f[i][max(j,I)][l]; else break; if (I==K) f[i+1][j][k]=l; } } } fo(j,0,K-1) { fo(l,0,K-1) if (j|l) { A=l; while (A<K) { g[1][j][A][l]=A; A+=max(j,A); } } } fo(i,2,n) { fo(j,0,K-1) { fo(l,0,K-1) if (j|l) { A=l; fo(k,0,K-1) if (A>-1) { g[i][j][k][l]=A; A=f[i][max(j,k)][A]; } else break; } } } fo(i,1,n) dp[0][i][0][1]=1; N=0; dfs(0,1); Dfs(0,1); printf("%lld\n",(ans+mod)%mod); fclose(stdin); fclose(stdout); return 0; }