多项式全家桶(更新至快速幂)
开始爆肝多项式
1. FFT快速傅里叶变换
流程: 用 \(\Theta(n\log n)\) 的时间将多项式从系数表示转成点值表示, 在点值表示下逐点相乘即完成卷积, 再用 \(\Theta(n\log n)\) 的时间转回系数表示
离散傅里叶变换:
朴素地转为点值表示, 需要把 n 个 x 逐个代入求值, 复杂度为 \(\Theta(n^2)\); 而傅里叶变换选取了一组具有特殊性质、可以用来优化的复数根(单位根)作为代入点
利用复数, 在复平面上画出一个单位圆, 从 1 开始将单位圆 n 等分, 第 k 个等分点记为 \(\omega_n^k\),
接下来将 \(\omega_n^0, \omega_n^1, \cdots, \omega_n^{n-1}\) 这 n 个单位根代入, 其中 \(\omega_n^k = \cos(\frac{k}{n}2\pi) + i\sin(\frac{k}{n}2\pi)\)
稍等, 还要几个小性质才可以:
性质一: \(\omega^{2k}_{2n} = \omega_n^k\)
性质二: \(\omega_n^{k+\frac{n}{2}} = - \omega_n^k\) (n 为偶数)
接下来开始快速傅里叶变换:
设A(x) = \(a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}\)
利用分治, 将A按x的指数分为奇偶两部分
\(A(x)=(a_0+a_2x^2+\cdots + a_{n-2}x^{n-2})+(a_1x+a_3x^3+\cdots + a_{n-1}x^{n-1})\)
设偶部分 \(A_1(x) = a_0 + a_2x + a_4x^2 + \cdots + a_{n-2}x^{\frac{n}{2}-1}\)
奇部分\(A_2(x) = a_1 + a_3x + a_5x^2 + \cdots + a_{n-1}x^{\frac{n}{2}-1}\)
则有: \(A(x) = A_1(x^2) + x A_2(x^2)\)
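设 \(k < \frac{n}{2}\), 代入 \(x = \omega^k_n\), 并由性质一 \(\omega^{2k}_n = \omega^k_{n/2}\) 得
\(A(\omega^k_n) = A_1(\omega^{2k}_n) + \omega_n^kA_2(\omega^{2k}_n) = A_1(\omega^{k}_{n/2}) + \omega_n^kA_2(\omega^{k}_{n/2})\)
再代入 \(\omega^{k+\frac{n}{2}}_n\), 利用 \(\omega^{2k+n}_n = \omega^{2k}_n\) 与性质二
得\(A(\omega^{k+\frac{n}{2}}_n) = A_1(\omega^{2k}_n) -\omega_n^kA_2(\omega^{2k}_n)\)
可以发现: 分治时只要求出 \(A_1\) 与 \(A_2\) 在 \(\omega^k_{n/2}\) 处的值, 就能一次得到 \(A(\omega^k_n)\) 与 \(A(\omega^{k+\frac{n}{2}}_n)\) 两个值, 问题规模每次减半, 总复杂度为 \(\Theta(n\log n)\)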
IDFT: 离散傅里叶逆变换
设\((y_0,y_1,\cdots,y_{n-1})\) 为多项式A(x) = \(a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}\)的离散傅里叶变换
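设多项式 \(B(x) = y_0 + y_1x + y_2x^2 + \cdots + y_{n-1}x^{n-1}\)
将 n 个单位根的倒数 \(\omega_n^{0}, \omega_n^{-1}, \cdots, \omega_n^{-(n-1)}\) 代入 B(x), 得到新的离散傅里叶变换 \((z_0, z_1, \cdots, z_{n-1})\), 即 \(z_k = B(\omega_n^{-k})\)
此处略去推导(关键在于 \(\sum_{j=0}^{n-1}\omega_n^{j(i-k)}\) 当 \(i = k\) 时为 n, 否则为 0), 得\(a_i = \frac{z_i}{n}\)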
非递归版fft
每个系数递归到底层时所在位置的二进制表示, 恰好是它原下标的二进制表示翻转得到的
如 (n = 8 时) 6(110) 到 3(011)
再加上蝴蝶变换就可以啦
// Polynomial multiplication with an iterative (non-recursive) complex FFT.
// Input : n m, then the n+1 coefficients of A and the m+1 coefficients of B,
//         lowest degree first.
// Output: the n+m+1 coefficients of A*B, space separated.
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;

const int N = 3000600;

// Reads a (possibly negative) decimal integer from stdin into x.
template <typename T>
void read(T &x) {
    x = 0;
    bool f = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar())
        if (c == '-') f = 1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    if (f) x = -x;
}

// Minimal complex number; operators take const references so temporaries bind.
struct Complex {
    double x, y;
    Complex(double xx = 0, double yy = 0) : x(xx), y(yy) {}
    Complex operator+(const Complex &i) const { return Complex(x + i.x, y + i.y); }
    Complex operator-(const Complex &i) const { return Complex(x - i.x, y - i.y); }
    Complex operator*(const Complex &i) const {
        return Complex(x * i.x - y * i.y, x * i.y + y * i.x);
    }
} A[N], B[N];

const double Pi = acos(-1.0);
int n, m;
int lim = 1, L;
int r[N];

// In-place iterative FFT over a[0..lim).
// type == 1 evaluates at the lim-th roots of unity; type == -1 evaluates at
// their inverses (the caller must still divide by lim to finish the inverse).
void FFT(Complex *a, int type) {
    for (int i = 0; i < lim; i++)
        if (r[i] > i) swap(a[r[i]], a[i]);  // bit-reversal permutation
    for (int mid = 1; mid < lim; mid <<= 1) {
        // Primitive (2*mid)-th root of unity: angle 2*Pi / (2*mid) == Pi / mid.
        Complex wn(cos(Pi / mid), type * sin(Pi / mid));
        for (int j = 0; j < lim; j += (mid << 1)) {
            Complex w(1, 0);
            for (int k = 0; k < mid; k++, w = w * wn) {
                Complex x = a[j + k], y = w * a[mid + j + k];
                a[k + j] = x + y;        // butterfly
                a[mid + j + k] = x - y;
            }
        }
    }
}

int main() {
    read(n), read(m);
    for (int i = 0; i <= n; i++) {
        int x;
        read(x);
        A[i].x = x;
    }
    for (int i = 0; i <= m; i++) {
        int x;
        read(x);
        B[i].x = x;
    }
    // lim: smallest power of two strictly greater than n+m, so the n+m+1
    // product coefficients fit without cyclic wrap-around.
    while (lim <= (n + m)) lim <<= 1, L++;
    // r[0] is always 0; starting at 1 also avoids shifting by L-1 == -1
    // (undefined behavior) when n == m == 0 and lim == 1.
    for (int i = 1; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    FFT(A, 1);
    FFT(B, 1);
    for (int i = 0; i < lim; i++) A[i] = A[i] * B[i];  // pointwise product
    FFT(A, -1);
    // Finish the inverse transform (divide by lim) and round to the nearest
    // integer; llround is also correct for negative coefficients.
    for (int i = 0; i <= n + m; i++) printf("%lld ", (long long)llround(A[i].x / lim));
    return 0;
}
2.NTT快速数论变换
直接上代码
// Polynomial multiplication modulo 998244353 using an iterative NTT.
// Input : n m, then the n+1 coefficients of A and the m+1 coefficients of B,
//         lowest degree first.
// Output: the n+m+1 coefficients of A*B mod P.
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;

const int N = 3e6 + 6;
const int P = 998244353;     // 119 * 2^23 + 1
const int G = 3;             // primitive root of P
const int Gi = (P + 1) / G;  // inverse of G mod P (3 * Gi == P + 1)

// Reads a signed decimal integer from stdin.
inline int read(void) {
    int x = 0, f = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar())
        if (c == '-') f = -1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    return x * f;
}

int n, m, r[N], L;
ll A[N], B[N];
int lim = 1;

// di^mi mod P by binary exponentiation.
// Named fpw rather than pow: under `using namespace std`, a visible std::pow
// template is an exact match for (int, int) arguments and would silently win
// overload resolution, returning a double instead of a modular power.
ll fpw(ll di, ll mi) {
    ll ans = 1, a = di;
    while (mi) {
        if (mi & 1) ans = ans * a % P;
        a = a * a % P;
        mi >>= 1;
    }
    return ans;
}

// In-place iterative NTT of a[0..lim). tag == 1 is the forward transform;
// tag == -1 uses the inverse root, and the caller multiplies by lim^{-1}.
void NTT(ll *a, int tag) {
    for (int i = 0; i < lim; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);  // bit-reversal permutation
    for (int mid = 1; mid < lim; mid <<= 1) {
        // Primitive (2*mid)-th root of unity mod P.
        ll wn = fpw(tag == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < lim; j += (mid << 1)) {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = (w * wn) % P) {
                ll x = a[j + k], y = w * a[j + k + mid] % P;
                a[j + k] = (x + y) % P;
                a[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}

int main() {
    n = read(), m = read();
    for (int i = 0; i <= n; i++) A[i] = read();
    for (int j = 0; j <= m; j++) B[j] = read();
    // Smallest power of two strictly greater than n+m.
    while (lim <= n + m) L++, lim <<= 1;
    // r[0] is always 0; starting at 1 also avoids shifting by L-1 == -1
    // (undefined behavior) when n == m == 0 and lim == 1.
    for (int i = 1; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(A, 1), NTT(B, 1);
    // Pointwise product over exactly the transform length [0, lim).
    for (int i = 0; i < lim; i++) A[i] = (A[i] * B[i]) % P;
    NTT(A, -1);
    ll inv = fpw(lim, P - 2);
    // The values are long long, so they must be printed with %lld.
    for (int i = 0; i <= n + m; i++) printf("%lld ", A[i] * inv % P);
    return 0;
}
3.多项式求逆
// Power-series inverse modulo 998244353: given A(x) of length n with
// A[0] != 0 (mod P), prints B(x) such that A(x) * B(x) == 1 (mod x^n).
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const int P = 998244353;
const int G = 3;             // primitive root of P
const int Gi = (P + 1) / G;  // inverse of G mod P (3 * Gi == P + 1)
const int N = 3e5 + 5;

// Reads a signed decimal integer from stdin into x.
template <typename T>
void read(T &x) {
    x = 0;
    int f = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar())
        if (c == '-') f = -1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}

ll A[N], n, lim = 1, L;
int r[N];

// di^mi mod P by binary exponentiation.
// Named fpw rather than pow: under `using namespace std`, a visible std::pow
// template is an exact match for integer arguments and would silently win
// overload resolution, returning a double instead of a modular power.
ll fpw(ll di, ll mi) {
    ll ans = 1;
    while (mi) {
        if (mi & 1) ans = (ans * di) % P;
        di = (di * di) % P;
        mi >>= 1;
    }
    return ans;
}

// In-place iterative NTT of a[0..lim). type == 1: forward transform;
// otherwise the inverse transform, including the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < lim; j += (mid << 1)) {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = (w * wn) % P) {
                ll x = a[j + k], y = w * a[j + k + mid] % P;
                a[j + k] = (x + y) % P, a[j + k + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv % P;
}

ll B[N], C[N];

// Newton iteration: on return b[0..deg) = a(x)^{-1} mod x^deg and
// b[deg..lim) is zero. Each step doubles the precision with
// b <- b * (2 - a * b), using a transform length >= 2 * deg.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 0; i < deg; i++) C[i] = a[i];  // a mod x^deg
    for (int i = deg; i < lim; i++) C[i] = 0;
    NTT(C, 1), NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = ((ll)2 - C[i] * b[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;  // truncate to mod x^deg
}

int main() {
    read(n);
    for (int i = 0; i < n; i++) read(A[i]);
    work(n, A, B);
    for (int i = 0; i < n; i++) printf("%lld ", B[i]);
    return 0;
}
4.分治FFT
// Divide-and-conquer ("online") convolution modulo 998244353:
//   f[0] = 1,  f[i] = sum_{j=1..i} f[i-j] * g[j]   for 1 <= i < n.
// Input : n, then g[1..n-1].  Output: f[0..n-1].
// Once f[l..mid] is final, its contribution to f[mid+1..R] is added with a
// single NTT convolution, and then the right half is solved recursively.
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;

const int P = 998244353;
const int G = 3, Gi = (P + 1) / G;  // primitive root of P and its inverse
const int N = 405000;

// base^e mod P by binary exponentiation.
ll fpw(ll base, ll e) {
    ll res = 1;
    for (; e; e >>= 1) {
        if (e & 1) res = res * base % P;
        base = base * base % P;
    }
    return res;
}

int lim, L;
ll r[N];

// Sets lim to the smallest power of two strictly greater than `need`
// (L = log2(lim)) and fills the matching bit-reversal table r[0..lim).
void prepare_transform(int need) {
    lim = 1, L = 0;
    while (lim <= need) {
        lim <<= 1;
        ++L;
    }
    for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

// In-place NTT of poly[0..lim). type == 1: forward; otherwise inverse,
// including the multiplication by lim^{-1}.
void NTT(ll *poly, int type) {
    for (int i = 0; i < lim; ++i)
        if (i < r[i]) swap(poly[i], poly[r[i]]);
    for (int half = 1; half < lim; half <<= 1) {
        const int step = half << 1;
        const ll root = fpw(type == 1 ? G : Gi, (P - 1) / step);
        for (int start = 0; start < lim; start += step) {
            ll w = 1;
            for (int k = start; k < start + half; ++k) {
                const ll u = poly[k], v = poly[k + half] * w % P;
                poly[k] = (u + v) % P;
                poly[k + half] = (u - v + P) % P;
                w = w * root % P;
            }
        }
    }
    if (type != 1) {
        const ll inv_lim = fpw(lim, P - 2);
        for (int i = 0; i < lim; ++i) poly[i] = poly[i] * inv_lim % P;
    }
}

ll lhs[N], f[N], rhs[N], g[N];

// Finalizes f[l..R], assuming contributions from f[0..l) are already added.
void solve(int l, int R) {
    if (l == R) return;
    const int mid = (l + R) >> 1;
    solve(l, mid);

    const int left_len = mid - l + 1;
    const int span = R - l + 1;
    for (int i = 0; i < left_len; ++i) lhs[i] = f[l + i];
    for (int i = 0; i < span; ++i) rhs[i] = g[i];

    prepare_transform(span + left_len);
    NTT(lhs, 1);
    NTT(rhs, 1);
    for (int i = 0; i < lim; ++i) lhs[i] = lhs[i] * rhs[i] % P;
    NTT(lhs, -1);

    // f[j] (l <= j <= mid) times g[i-j] lands at index (j-l) + (i-j) = i-l.
    for (int i = mid + 1; i <= R; ++i) f[i] = (f[i] + lhs[i - l]) % P;

    memset(lhs, 0, sizeof(ll) * lim);
    memset(rhs, 0, sizeof(ll) * lim);
    solve(mid + 1, R);
}

// Reads a non-negative decimal integer, skipping leading non-digit chars.
int read(void) {
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    int x = 0;
    while (isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

int n;

int main() {
    n = read();
    for (int i = 1; i < n; ++i) g[i] = read();
    f[0] = 1;
    solve(0, n - 1);
    for (int i = 0; i < n; ++i) printf("%lld ", f[i]);
    return 0;
}
5.多项式ln:
// Power-series logarithm modulo 998244353: given A(x) of length n with
// A[0] == 1, prints ln A(x) mod x^n, computed as the integral of A'(x) / A(x).
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const int N = 505000;
const int P = 998244353;
const int G = 3;             // primitive root of P
const int Gi = (P + 1) / G;  // inverse of G mod P

// Reads a signed decimal integer from stdin.
int read(void) {
    int x = 0;
    bool f = 0;
    char c = getchar();
    for (; !isdigit(c); c = getchar())
        if (c == '-') f = 1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return f ? -x : x;
}

// di^mi mod P by binary exponentiation.
ll fpw(ll di, ll mi) {
    ll ans = 1;
    while (mi) {
        if (mi & 1) ans = ans * di % P;
        di = di * di % P;
        mi >>= 1;
    }
    return ans;
}

ll A[N], B[N], C[N];
int r[N], n, lim, L;

// Fills the bit-reversal table r[0..lim) for the current lim == 2^L (L >= 1).
void build_rev(void) {
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

// In-place iterative NTT of a[0..lim). type == 1: forward transform;
// otherwise the inverse transform, including the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int j = 1; j < lim; j <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (j << 1));
        for (int k = 0; k < lim; k += j << 1) {
            ll w = 1;
            for (int p = 0; p < j; p++, w = w * wn % P) {
                ll x = a[p + k], y = a[p + k + j] * w % P;
                a[p + k] = (x + y) % P;
                a[p + k + j] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv % P;
}

// Newton iteration: on return b[0..deg) = a(x)^{-1} mod x^deg and
// b[deg..lim) is zero. For deg >= 2 it leaves lim = smallest power of two
// >= 2 * deg with matching L and r[]; for deg == 1 they are not touched.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    build_rev();
    for (int i = 0; i < deg; i++) C[i] = a[i];
    for (int i = deg; i < lim; i++) C[i] = 0;
    NTT(C, 1);
    NTT(b, 1);
    // b <- b * (2 - a * b), evaluated pointwise.
    for (int i = 0; i < lim; i++) b[i] = ((ll)2 - C[i] * b[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

// A <- A' over the first n coefficients: A[i-1] = i * A[i], A[n-1] = 0.
void qiudoor(void) {
    for (int i = 1; i < n; i++) A[i - 1] = A[i] * i % P;
    A[n - 1] = 0;
}

// A <- integral of A over the first n coefficients: A[i] = A[i-1] / i, A[0] = 0.
void jinitaimei(void) {
    for (int i = n - 1; i >= 1; i--) A[i] = A[i - 1] * fpw(i, P - 2) % P;
    A[0] = 0;
}

int main() {
    n = read();
    for (int i = 0; i < n; i++) A[i] = read();
    work(n, A, B);  // B = 1 / A mod x^n
    qiudoor();      // A <- A'
    // deg(A') + deg(B) <= (n-2) + (n-1) < 2n: a length > 2n avoids wrap-around.
    lim = 1, L = 0;
    while (lim <= n + n) lim <<= 1, L++;
    // r[] must be rebuilt for this lim. work() sized its table for the smallest
    // power of two >= 2n, which is smaller than this lim whenever n is a power
    // of two (and for n == 1 work() never built the table at all).
    build_rev();
    NTT(A, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; i++) A[i] = A[i] * B[i] % P;
    NTT(A, -1);
    jinitaimei();
    for (int i = 0; i < n; i++) printf("%lld ", A[i] % P);
    return 0;
}
6.多项式exp
// Power-series exponential modulo 998244353: given F(x) of length n with
// F[0] == 0, prints exp F(x) mod x^n using the Newton iteration
//   g <- g * (1 - ln g + F).
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#define ll long long
using namespace std;
const int P = 998244353;
const int G = 3, Gi = (P + 1) / G;  // primitive root of P and its inverse
const int N = 1005000;

// Reads a non-negative decimal integer, skipping leading non-digit chars.
int read(void) {
    int x = 0;
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return x;
}

// di^mi mod P by binary exponentiation.
ll fpw(ll di, ll mi) {
    ll res = 1;
    while (mi) {
        if (mi & 1) res = res * di % P;
        di = di * di % P;
        mi >>= 1;
    }
    return res;
}

int lim, L;
int r[N];

// In-place iterative NTT of a[0..lim) using the bit-reversal table r[].
// type == 1: forward transform; otherwise the inverse transform, including
// the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (r[i] > i) swap(a[r[i]], a[i]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int i = 0; i < lim; i += mid << 1) {
            ll w = 1;
            for (int j = 0; j < mid; j++, w = w * wn % P) {
                ll x = a[i + j], y = a[i + j + mid] * w % P;
                a[i + j] = (x + y) % P;
                a[i + j + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv_lim = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv_lim % P;
}

ll c[N];

// Newton iteration for the power-series inverse: on return b[0..deg) holds
// a(x)^{-1} mod x^deg and b[deg..lim) is zero (requires a[0] != 0 mod P;
// callers pass a zeroed b). For deg >= 2 it leaves lim = smallest power of
// two >= 2 * deg with matching L and r[]; for deg == 1 they are not touched.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 0; i < deg; i++) c[i] = a[i];
    for (int i = deg; i < lim; i++) c[i] = 0;
    NTT(b, 1);
    NTT(c, 1);
    // b <- b * (2 - a * b), evaluated pointwise.
    for (int i = 0; i < lim; i++) b[i] = (2 - b[i] * c[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

int n;

// Differentiates in place: a[i-1] = i * a[i] for 1 <= i < n, a[n-1] = 0.
// NOTE: this runs over the global input length n; deg is unused.
void qiudoor(ll *a, int deg) {
    for (int i = 1; i < n; i++) a[i - 1] = a[i] * i % P;
    a[n - 1] = 0;
}

ll inv[N];

// inv[i] = i^{-1} mod P for 1 <= i <= 2n, via inv[i] = -(P / i) * inv[P % i].
void init(void) {
    inv[1] = 1;
    for (int i = 2; i <= n * 2; i++) inv[i] = (P - P / i * inv[P % i] % P) % P;
}

// Integrates in place: a[i] = a[i-1] / i for n > i >= 1, a[0] = 0.
// NOTE: like qiudoor, this runs over the global n; deg is unused.
void jinitaimei(ll *a, int deg) {
    for (int i = n - 1; i >= 1; i--) a[i] = a[i - 1] * inv[i] % P;
    a[0] = 0;
}

ll A[N], B[N];

// Replaces a with ln a = integral(a' / a) (requires a[0] == 1).
// B is used as scratch for the inverse of a mod x^n.
void get_ln(ll *a, ll n) {
    for (int i = 0; i < n * 2; i++) B[i] = 0;
    work(n, a, B);
    qiudoor(a, n);
    NTT(a, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; i++) a[i] = a[i] * B[i] % P;
    NTT(a, -1);
    jinitaimei(a, n);
}

// Newton iteration for exp: on return b[0..deg) = exp(a) mod x^deg.
// main passes a power of two, so every level halves exactly. On entry b is
// zero at indices >= deg/2 (it comes from the previous level), which keeps
// the length-2*deg cyclic product below from wrapping around.
void solve(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = 1;
        return;
    }
    solve(deg >> 1, a, b);
    for (int i = 0; i < deg; i++) A[i] = b[i];
    get_ln(A, deg);  // also leaves lim == 2 * deg with a matching r[]
    A[0] = (a[0] + 1 - A[0] + P) % P;                       // 1 - ln b + a
    for (int i = 1; i < deg; i++) A[i] = (a[i] - A[i] + P) % P;
    NTT(A, 1);
    NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = b[i] * A[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

ll f[N], g[N];

int main() {
    n = read();
    init();
    for (int i = 0; i < n; i++) f[i] = read();
    lim = 1;
    while (lim <= n) lim <<= 1;
    solve(lim, f, g);
    // g holds long long values, so %lld is required (%d is undefined behavior).
    for (int i = 0; i < n; i++) printf("%lld ", g[i]);
    return 0;
}
7.多项式开根
// Power-series square root modulo 998244353: given F(x) of length n with
// F[0] == 1, prints sqrt F(x) mod x^n computed as exp(ln(F) / 2).
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#define ll long long
using namespace std;
const int P = 998244353;
const int G = 3, Gi = (P + 1) / G;  // primitive root of P and its inverse
const int N = 1005000;

// Reads a non-negative decimal integer, skipping leading non-digit chars.
int read(void) {
    int x = 0;
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return x;
}

// di^mi mod P by binary exponentiation.
ll fpw(ll di, ll mi) {
    ll res = 1;
    while (mi) {
        if (mi & 1) res = res * di % P;
        di = di * di % P;
        mi >>= 1;
    }
    return res;
}

int lim, L;
int r[N];

// In-place iterative NTT of a[0..lim) using the bit-reversal table r[].
// type == 1: forward transform; otherwise the inverse transform, including
// the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (r[i] > i) swap(a[r[i]], a[i]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int i = 0; i < lim; i += mid << 1) {
            ll w = 1;
            for (int j = 0; j < mid; j++, w = w * wn % P) {
                ll x = a[i + j], y = a[i + j + mid] * w % P;
                a[i + j] = (x + y) % P;
                a[i + j + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv_lim = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv_lim % P;
}

ll c[N];

// Newton iteration for the power-series inverse: on return b[0..deg) holds
// a(x)^{-1} mod x^deg and b[deg..lim) is zero (requires a[0] != 0 mod P;
// callers pass a zeroed b). For deg >= 2 it leaves lim = smallest power of
// two >= 2 * deg with matching L and r[]; for deg == 1 they are not touched.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 0; i < deg; i++) c[i] = a[i];
    for (int i = deg; i < lim; i++) c[i] = 0;
    NTT(b, 1);
    NTT(c, 1);
    // b <- b * (2 - a * b), evaluated pointwise.
    for (int i = 0; i < lim; i++) b[i] = (2 - b[i] * c[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

int n;

// Differentiates in place: a[i-1] = i * a[i] for 1 <= i < n, a[n-1] = 0.
// NOTE: this runs over the global input length n; deg is unused.
void qiudoor(ll *a, int deg) {
    for (int i = 1; i < n; i++) a[i - 1] = a[i] * i % P;
    a[n - 1] = 0;
}

ll inv[N];

// inv[i] = i^{-1} mod P for 1 <= i <= 2n, via inv[i] = -(P / i) * inv[P % i].
void init(void) {
    inv[1] = 1;
    for (int i = 2; i <= n * 2; i++) inv[i] = (P - P / i * inv[P % i] % P) % P;
}

// Integrates in place: a[i] = a[i-1] / i for n > i >= 1, a[0] = 0.
// NOTE: like qiudoor, this runs over the global n; deg is unused.
void jinitaimei(ll *a, int deg) {
    for (int i = n - 1; i >= 1; i--) a[i] = a[i - 1] * inv[i] % P;
    a[0] = 0;
}

ll A[N], B[N];

// Replaces a with ln a = integral(a' / a) (requires a[0] == 1).
// B is used as scratch for the inverse of a mod x^n.
void get_ln(ll *a, ll n) {
    for (int i = 0; i < n * 2; i++) B[i] = 0;
    work(n, a, B);
    qiudoor(a, n);
    NTT(a, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; i++) a[i] = a[i] * B[i] % P;
    NTT(a, -1);
    jinitaimei(a, n);
}

// Newton iteration for exp: on return b[0..deg) = exp(a) mod x^deg.
// main passes a power of two, so every level halves exactly. On entry b is
// zero at indices >= deg/2 (it comes from the previous level), which keeps
// the length-2*deg cyclic product below from wrapping around.
void solve(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = 1;
        return;
    }
    solve(deg >> 1, a, b);
    for (int i = 0; i < deg; i++) A[i] = b[i];
    get_ln(A, deg);  // also leaves lim == 2 * deg with a matching r[]
    A[0] = (a[0] + 1 - A[0] + P) % P;                       // 1 - ln b + a
    for (int i = 1; i < deg; i++) A[i] = (a[i] - A[i] + P) % P;
    NTT(A, 1);
    NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = b[i] * A[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

ll f[N], g[N];

int main() {
    n = read();
    init();
    for (int i = 0; i < n; i++) f[i] = read();
    get_ln(f, n);
    for (int i = 0; i < n; i++) f[i] = f[i] * inv[2] % P;  // ln(F) / 2
    lim = 1;
    while (lim <= n) lim <<= 1;
    solve(lim, f, g);
    // g holds long long values, so %lld is required (%d is undefined behavior).
    for (int i = 0; i < n; i++) printf("%lld ", g[i]);
    return 0;
}
8.多项式除法
// Polynomial division modulo 998244353: given F (degree n) and G (degree m),
// prints Q (degree n-m) on the first line and R (degree < m) on the second,
// where F = Q * G + R. Reversal trick:
//   Q_R = F_R * G_R^{-1}  (mod x^{n-m+1}),   R = F - Q * G.
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const int P = 998244353;
const int G = 3;             // primitive root of P
const int Gi = (P + 1) / G;  // inverse of G mod P
const int N = 3e5 + 5;

// Reads a signed decimal integer from stdin into x.
template <typename T>
void read(T &x) {
    x = 0;
    int f = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar())
        if (c == '-') f = -1;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    x *= f;
}

ll n, lim = 1, L;
int r[N];

// di^mi mod P by binary exponentiation.
// Named fpw rather than pow: under `using namespace std`, a visible std::pow
// template is an exact match for integer arguments and would silently win
// overload resolution, returning a double instead of a modular power.
ll fpw(ll di, ll mi) {
    ll ans = 1;
    while (mi) {
        if (mi & 1) ans = (ans * di) % P;
        di = (di * di) % P;
        mi >>= 1;
    }
    return ans;
}

// In-place iterative NTT of a[0..lim). type == 1: forward transform;
// otherwise the inverse transform, including the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (i < r[i]) swap(a[i], a[r[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < lim; j += (mid << 1)) {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = (w * wn) % P) {
                ll x = a[j + k], y = w * a[j + k + mid] % P;
                a[j + k] = (x + y) % P, a[j + k + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv % P;
}

ll B[N], C[N];

// Newton iteration: on return b[0..deg) = a(x)^{-1} mod x^deg and
// b[deg..lim) is zero. For deg >= 2 it sets lim (>= 2 * deg), L and r[];
// for deg == 1 they are left unchanged.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 0; i < deg; i++) C[i] = a[i];
    for (int i = deg; i < lim; i++) C[i] = 0;
    NTT(C, 1), NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = ((ll)2 - C[i] * b[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

ll f[N], g[N], q[N], ff[N], gg[N];
int m;

int main() {
    // Reads stdin / writes stdout (no file redirection, so it runs on a judge).
    read(n), read(m);
    // f / g receive the reversed coefficient lists F_R / G_R; ff / gg keep F / G.
    for (int i = n; i >= 0; i--) read(f[i]), ff[n - i] = f[i];
    for (int j = m; j >= 0; j--) read(g[j]), gg[m - j] = g[j];
    // Only the low n-m+1 coefficients of F_R and G_R matter for Q_R.
    for (int i = n - m + 1; i <= n; i++) f[i] = g[i] = 0;
    work(n - m + 1, g, B);  // B = G_R^{-1} mod x^{n-m+1}
    NTT(B, 1);
    NTT(f, 1);
    for (int i = 0; i < lim; i++) q[i] = f[i] * B[i] % P;
    NTT(q, -1);
    reverse(q, q + n - m + 1);  // Q = reverse(Q_R mod x^{n-m+1})
    for (int i = 0; i <= n - m; i++) printf("%lld ", q[i]);

    // R = F - Q * G; deg(Q * G) == n, so a length > n suffices.
    lim = 1, L = 0;
    while (lim <= n) lim <<= 1, L++;
    for (int i = n - m + 1; i < lim; i++) q[i] = 0;
    // r[0] is always 0; starting at 1 avoids a shift by L-1 == -1 when lim == 1.
    for (int i = 1; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(gg, 1);
    NTT(q, 1);
    for (int i = 0; i < lim; i++) gg[i] = gg[i] * q[i] % P;
    NTT(gg, -1);
    putchar('\n');
    for (int i = 0; i < m; i++) printf("%lld ", (ff[i] - gg[i] + P) % P);
    return 0;
}
9.拉格朗日插值
// Lagrange interpolation modulo 998244353: given n points (x_i, y_i) with
// distinct x_i, prints f(k) for the unique polynomial f of degree < n
// through them:  f(k) = sum_i y_i * prod_{j != i} (k - x_j) / (x_i - x_j).
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
using namespace std;
const int P = 998244353;
const int N = 3e5 + 5;

// Reads a signed decimal integer from stdin into x.
template <typename T>
void read(T &x) {
    int sign = 1;
    char c = getchar();
    while (!isdigit(c)) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    x = 0;
    while (isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}

// base^e mod P. A negative base yields a result that may itself be negative,
// but it is still congruent to the true power modulo P.
ll fpw(ll base, ll e) {
    ll acc = 1;
    for (; e; e >>= 1, base = base * base % P)
        if (e & 1) acc = acc * base % P;
    return acc;
}

int x[N], y[N], n, k;

int main() {
    read(n), read(k);
    for (int i = 1; i <= n; i++) {
        read(x[i]);
        read(y[i]);
    }

    ll answer = 0;
    int i = 1;
    while (i <= n) {
        // numer = prod (k - x_j), denom = prod (x_i - x_j) over j != i.
        ll numer = 1, denom = 1;
        for (int j = 1; j <= n; j++) {
            if (j != i) {
                numer = numer * (k - x[j]) % P;
                denom = denom * (x[i] - x[j]) % P;
            }
        }
        const ll basis = numer * fpw(denom, P - 2) % P;  // l_i(k)
        answer += y[i] * basis % P;
        ++i;
    }

    // Intermediate terms may be negative; normalize into [0, P).
    answer %= P;
    if (answer < 0) answer += P;
    cout << answer << endl;
    return 0;
}
10. 多项式快速幂
// Power-series power modulo 998244353: given F(x) of length n with F[0] == 1
// and a (possibly huge) exponent k, prints F(x)^k mod x^n computed as
// exp(k * ln F). Since only k * ln F is needed, k is read modulo P.
#include <cctype>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#define ll long long
using namespace std;
const int P = 998244353;
const int G = 3, Gi = (P + 1) / G;  // primitive root of P and its inverse
const int N = 1005000;

// Reads a non-negative decimal integer, skipping leading non-digit chars.
int read(void) {
    int x = 0;
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + (c ^ 48);
    return x;
}

// di^mi mod P by binary exponentiation.
ll fpw(ll di, ll mi) {
    ll res = 1;
    while (mi) {
        if (mi & 1) res = res * di % P;
        di = di * di % P;
        mi >>= 1;
    }
    return res;
}

int lim, L;
int r[N];

// In-place iterative NTT of a[0..lim) using the bit-reversal table r[].
// type == 1: forward transform; otherwise the inverse transform, including
// the multiplication by lim^{-1}.
void NTT(ll *a, int type) {
    for (int i = 0; i < lim; i++)
        if (r[i] > i) swap(a[r[i]], a[i]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        ll wn = fpw(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int i = 0; i < lim; i += mid << 1) {
            ll w = 1;
            for (int j = 0; j < mid; j++, w = w * wn % P) {
                ll x = a[i + j], y = a[i + j + mid] * w % P;
                a[i + j] = (x + y) % P;
                a[i + j + mid] = (x - y + P) % P;
            }
        }
    }
    if (type == 1) return;
    ll inv_lim = fpw(lim, P - 2);
    for (int i = 0; i < lim; i++) a[i] = a[i] * inv_lim % P;
}

ll c[N];

// Newton iteration for the power-series inverse: on return b[0..deg) holds
// a(x)^{-1} mod x^deg and b[deg..lim) is zero (requires a[0] != 0 mod P;
// callers pass a zeroed b). For deg >= 2 it leaves lim = smallest power of
// two >= 2 * deg with matching L and r[]; for deg == 1 they are not touched.
void work(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = fpw(a[0], P - 2);
        return;
    }
    work((deg + 1) >> 1, a, b);
    lim = 1, L = 0;
    while (lim < (deg << 1)) lim <<= 1, L++;
    for (int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for (int i = 0; i < deg; i++) c[i] = a[i];
    for (int i = deg; i < lim; i++) c[i] = 0;
    NTT(b, 1);
    NTT(c, 1);
    // b <- b * (2 - a * b), evaluated pointwise.
    for (int i = 0; i < lim; i++) b[i] = (2 - b[i] * c[i] % P + P) % P * b[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

int n;

// Differentiates in place: a[i-1] = i * a[i] for 1 <= i < n, a[n-1] = 0.
// NOTE: this runs over the global input length n; deg is unused.
void qiudoor(ll *a, int deg) {
    for (int i = 1; i < n; i++) a[i - 1] = a[i] * i % P;
    a[n - 1] = 0;
}

ll inv[N];

// inv[i] = i^{-1} mod P for 1 <= i <= 2n, via inv[i] = -(P / i) * inv[P % i].
void init(void) {
    inv[1] = 1;
    for (int i = 2; i <= n * 2; i++) inv[i] = (P - P / i * inv[P % i] % P) % P;
}

// Integrates in place: a[i] = a[i-1] / i for n > i >= 1, a[0] = 0.
// NOTE: like qiudoor, this runs over the global n; deg is unused.
void jinitaimei(ll *a, int deg) {
    for (int i = n - 1; i >= 1; i--) a[i] = a[i - 1] * inv[i] % P;
    a[0] = 0;
}

ll A[N], B[N];

// Replaces a with ln a = integral(a' / a) (requires a[0] == 1).
// B is used as scratch for the inverse of a mod x^n.
void get_ln(ll *a, ll n) {
    for (int i = 0; i < n * 2; i++) B[i] = 0;
    work(n, a, B);
    qiudoor(a, n);
    NTT(a, 1);
    NTT(B, 1);
    for (int i = 0; i < lim; i++) a[i] = a[i] * B[i] % P;
    NTT(a, -1);
    jinitaimei(a, n);
}

// Newton iteration for exp: on return b[0..deg) = exp(a) mod x^deg.
// main passes a power of two, so every level halves exactly. On entry b is
// zero at indices >= deg/2 (it comes from the previous level), which keeps
// the length-2*deg cyclic product below from wrapping around.
void solve(int deg, ll *a, ll *b) {
    if (deg == 1) {
        b[0] = 1;
        return;
    }
    solve(deg >> 1, a, b);
    for (int i = 0; i < deg; i++) A[i] = b[i];
    get_ln(A, deg);  // also leaves lim == 2 * deg with a matching r[]
    A[0] = (a[0] + 1 - A[0] + P) % P;                       // 1 - ln b + a
    for (int i = 1; i < deg; i++) A[i] = (a[i] - A[i] + P) % P;
    NTT(A, 1);
    NTT(b, 1);
    for (int i = 0; i < lim; i++) b[i] = b[i] * A[i] % P;
    NTT(b, -1);
    for (int i = deg; i < lim; i++) b[i] = 0;
}

ll f[N], g[N];

// Reads a decimal exponent of arbitrary length, reduced modulo P on the fly.
ll get_k(void) {
    ll x = 0;
    char c = getchar();
    while (!isdigit(c)) c = getchar();
    for (; isdigit(c); c = getchar()) {
        x = (x << 3) + (x << 1) + (c ^ 48);
        if (x >= P) x %= P;
    }
    return x;
}

ll k;

int main() {
    n = read();
    init();
    k = get_k();
    for (int i = 0; i < n; i++) f[i] = read();
    get_ln(f, n);
    for (int i = 0; i < n; i++) f[i] = f[i] * k % P;  // k * ln F
    lim = 1;
    while (lim <= n) lim <<= 1;
    solve(lim, f, g);
    // g holds long long values, so %lld is required (%d is undefined behavior).
    for (int i = 0; i < n; i++) printf("%lld ", g[i]);
    return 0;
}