超好写的
#include<stdio.h> #define inf 998244353 #define I 86583718 #define MAXN 8388608 char buf[1<<20],*p1,*p2; #define GC (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?0:*p1++) inline int R(){ char t=GC; int x=0; while(t<'!')t=GC; while(t>='!')x=x*10+t-48,t=GC; return x; } int _0[MAXN],_lset=-1,inv[MAXN],tmp[MAXN]; inline void init(){ inv[1]=1; for(int i=2;i<MAXN;i++)inv[i]=inf-1ull*inv[inf%i]*(inf/i)%inf; } inline void mod(int &p){p>=inf?(p-=inf):0;} inline int getinv(int p){return p<3e6+5?inv[p]:inf-1ull*getinv(inf%p)*(inf/p)%inf;} inline int ksm(long long a,int b){int ans=1;while(b)(b&1)&&(ans=a*ans%inf),a=a*a%inf,b>>=1;return ans;} inline void init(int n){if(_lset==n)return;_lset=n;_0[0]=0;for(int i=1;i<n;++i)_0[i]=i&1?_0[i^1]|n>>1:_0[i>>1]>>1;} inline void mul(int a[],int b[],int n){for(register int i=0;i<n;++i)a[i]=1ull*a[i]*b[i]%inf;} inline void clr(int a[],int l,int r){for(register int i=l;i<r;++i)a[i]=0;} inline void cpy(int a[],int b[],int n){for(register int i=0;i<n;++i)b[i]=a[i];} inline void rev(int a[],int b[],int n){for(register int i=0;i<n;++i)b[i]=a[n-i-1];} inline int gett(int n){while(n!=(n&-n))n^=n&-n;return n<<1;} inline void ntt(int a[],int n,bool typ){ init(n);register int p,q; for(register int i=0;i<n;++i)if(i<_0[i])a[i]^=a[_0[i]]^=a[i]^=a[_0[i]]; for(register int i=1;i<n;i<<=1){ register int w=ksm(typ?332748118:3,(inf-1>>1)/i); register unsigned long long h=1; for(register int k=0;k<i;++k,h=w*h%inf)tmp[k]=h; for(register int j=0;j<n;j+=i<<1) for(register int k=0;k<i;++k) p=a[j+k],q=1ull*tmp[k]*a[i+j+k]%inf,mod(a[j+k]=p+q),mod(a[i+j+k]=p+inf-q); }if(typ){ register unsigned long long t=ksm(n,inf-2); for(register int i=0;i<n;i++)a[i]=t*a[i]%inf; } } inline void getinv(int a[],int b[],int tmp[],int n){ clr(b,0,n<<1); b[0]=ksm(a[0],inf-2); for(int i=1;i<n;i<<=1){ clr(tmp,0,i<<2); cpy(a,tmp,i<<1); ntt(tmp,i<<2,0); ntt(b,i<<2,0); mul(tmp,b,i<<2); mul(tmp,b,i<<2); for(int j=0;j<i<<2;j++)mod(tmp[j]=b[j]+inf-tmp[j]),mod(tmp[j]+=b[j]); cpy(tmp,b,i<<2); ntt(b,i<<2,1); clr(b,i<<1,i<<2); } } inline void getln(int a[],int b[],int tmp[],int n){ getinv(a,b,tmp,n); clr(tmp,0,n<<1); for(int i=1;i<n;i++)tmp[i-1]=1ull*a[i]*i%inf; ntt(b,n<<1,0); ntt(tmp,n<<1,0); mul(tmp,b,n<<1); ntt(tmp,n<<1,1); clr(b,0,n<<1); for(int i=1;i<n;i++)b[i]=1ull*tmp[i-1]*inv[i]%inf; } inline void getexp(int a[],int b[],int tmp[],int tmp2[],int n){ clr(b,0,n<<1); b[0]=1; for(int i=1;i<n;i<<=1){ getln(b,tmp,tmp2,i<<1); for(int j=0;j<i<<1;j++)mod(tmp[j]=a[j]+inf-tmp[j]); mod(tmp[0]+=1); ntt(tmp,i<<2,0); ntt(b,i<<2,0); mul(b,tmp,i<<2); ntt(b,i<<2,1); } } inline void getdiv(int a[],int b[],int c[],int d[],int tmp[],int n,int m){ int len=gett(n); clr(c,0,len<<1); rev(b,c,m); getinv(c,d,tmp,len); clr(c,0,len<<1); ntt(d,len<<1,0); clr(tmp,0,len<<1); rev(a,tmp,n); ntt(tmp,len<<1,0); mul(d,tmp,len<<1); ntt(d,len<<1,1); rev(d,c,n-m+1); clr(d,0,len<<1); clr(tmp,0,len<<1); cpy(c,d,len<<1); cpy(b,tmp,len<<1); ntt(d,len<<1,0); ntt(tmp,len<<1,0); mul(d,tmp,len<<1); ntt(d,len<<1,1); for(int i=0;i<n;i++)mod(d[i]=a[i]+inf-d[i]); } inline void getsqrt(int a[],int b[],int tmp[],int tmp2[],int n){ clr(b,0,n<<1); b[0]=1; for(int i=1;i<n;i<<=1){ getinv(b,tmp,tmp2,i<<1); for(int j=0;j<i<<2;j++)tmp[j]=499122177ull*tmp[j]%inf; ntt(b,i<<2,0); mul(b,b,i<<2); ntt(b,i<<2,1); for(int j=0;j<i<<1;j++)mod(b[j]+=a[j]); ntt(b,i<<2,0); ntt(tmp,i<<2,0); mul(b,tmp,i<<2); ntt(b,i<<2,1); clr(b,i<<1,i<<2); } } inline void getsin(int a[],int b[],int tmp[],int tmp2[],int tmp3[],int n){ for(int i=0;i<n;i++)a[i]=1ull*I*a[i]%inf; getexp(a,b,tmp2,tmp3,n); for(int i=0;i<n;i++)mod(a[i]=inf-a[i]); getexp(a,tmp,tmp2,tmp3,n); for(int i=0;i<n;i++)a[i]=1ull*I*a[i]%inf,b[i]=499122177ull*(tmp[i]+inf-b[i])%inf*I%inf; } inline void getcos(int a[],int b[],int tmp[],int tmp2[],int tmp3[],int n){ for(int i=0;i<n;i++)a[i]=1ull*I*a[i]%inf; getexp(a,b,tmp2,tmp3,n); for(int i=0;i<n;i++)mod(a[i]=inf-a[i]); getexp(a,tmp,tmp2,tmp3,n); for(int i=0;i<n;i++)a[i]=1ull*I*a[i]%inf,b[i]=499122177ull*(tmp[i]+b[i])%inf; }