洛谷P4238【模板】多项式求逆

◇◆丶佛笑我妖孽 提交于 2020-01-16 05:38:28

洛谷P4238

多项式求逆:http://blog.miskcoo.com/2015/05/polynomial-inverse

注意:直接在点值表达下做$B(x) \equiv 2B'(x) - A(x)B'^2(x) \pmod {x^n}$是可以的,但是一定要注意,这一步中有一个长度为n的和两个长度为(n/2)的多项式相乘,因此要在DFT前就扩展FFT点值表达的“长度”到2n,否则会出错(调了1.5个小时)

备份

版本1:

 1 #prag\
 2 ma GCC optimize(2)
 3 #include<cstdio>
 4 #include<algorithm>
 5 #include<cstring>
 6 #include<vector>
 7 #include<cmath>
 8 using namespace std;
 9 #define fi first
10 #define se second
11 #define mp make_pair
12 #define pb push_back
13 typedef long long ll;
14 typedef unsigned long long ull;
15 const int md=998244353;
16 const int N=2097152;
17 int rev[N];
18 void init(int len)
19 {
20     int bit=0,i;
21     while((1<<(bit+1))<=len)    ++bit;
22     for(i=0;i<len;++i)
23         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
24 }
25 ll poww(ll a,ll b)
26 {
27     ll base=a,ans=1;
28     for(;b;b>>=1,base=base*base%md)
29         if(b&1)
30             ans=ans*base%md;
31     return ans;
32 }
33 void dft(int *a,int len,int idx)//要求len为2的幂
34 {
35     int i,j,k,t1,t2;ll wn,wnk;
36     for(i=0;i<len;++i)
37         if(i<rev[i])
38             swap(a[i],a[rev[i]]);
39     for(i=1;i<len;i<<=1)
40     {
41         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
42         for(j=0;j<len;j+=(i<<1))
43         {
44             wnk=1;
45             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
46             {
47                 t1=a[k];t2=a[k+i]*wnk%md;
48                 a[k]+=t2;
49                 (a[k]>=md) && (a[k]-=md);
50                 a[k+i]=t1-t2;
51                 (a[k+i]<0) && (a[k+i]+=md);
52             }
53         }
54     }
55     if(idx==-1)
56     {
57         ll ilen=poww(len,md-2);
58         for(i=0;i<len;++i)
59             a[i]=a[i]*ilen%md;
60     }
61 }
62 int f[N],g[N],t1[N];
63 int n,n1;
64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) 
65 {
66     g[0]=poww(f[0],md-2);
67     for(int i=2,j;i<(len<<1);i<<=1)
68     {
69         init(i<<1);
70         memcpy(t1,f,sizeof(int)*i);
71         memset(t1+i,0,sizeof(int)*i);
72         memset(g+(i>>1),0,sizeof(int)*(i+(i>>1)));
73         dft(t1,i<<1,1);dft(g,i<<1,1);
74         for(j=0;j<(i<<1);++j)
75             g[j]=ll(g[j])*(2+ll(md-g[j])*t1[j]%md)%md;
76         dft(g,i<<1,-1);
77     }
78 }
79 int main()
80 {
81     int i,t;
82     scanf("%d",&n);n1=n;
83     for(i=0;i<n;++i)
84         scanf("%d",g+i);
85     for(t=1;t<n;t<<=1);
86     n=t;
87     p_inv(g,f,n);
88     for(i=0;i<n1;++i)
89         printf("%d ",f[i]);
90     return 0;
91 }
View Code

资料:https://www.luogu.org/blog/user7035/duo-xiang-shi-zong-jie

里面有一个迷之优化(代码好像和文字表述的不一样,很玄学,看不懂,被坑了...)

牛顿迭代得到式子:$B(x) \equiv B'(x)-B'(x)(A(x)B'(x)-1) \pmod {x^n}$,其中B'(x)是上一次迭代的结果,B(x)是这一次的结果,A(x)是原多项式,n是这一次迭代得到的结果长度(设它是2的幂);设上一次迭代得到的结果长度为m=n/2

看右边的$A(x)B'(x)-1$,可以知道它第0到m-1项都是0,现在只需要求它与B'(x)的乘积的前n位,可以把它”左移“m位,这样它和B'(x)长度都只有m,因此只需要做长度为n(而不是2n)的NTT,然后再”右移”回去

如果与B'(x)相乘时不做长度为2n的NTT而做长度为n的NTT,那么可以发现结果刚好相当于正常结果(做长度为2n的NTT的结果取前n位)将前一半和后一半交换(未验证)

(可以直接用算A(x)B'(x)时求出的B'(x)的DFT)(当然这样NTT次数从3次变成了5次...)

版本2:(实测的确比版本1快)(另外把longlong都改成了unsignedlonglong)

  1 #prag\
  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=262144;
 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 18 int rev[N];
 19 void init(int len)
 20 {
 21     int bit=0,i;
 22     while((1<<(bit+1))<=len)    ++bit;
 23     for(i=0;i<len;++i)
 24         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 25 }
 26 ull poww(ull a,ull b)
 27 {
 28     ull base=a,ans=1;
 29     for(;b;b>>=1,base=base*base%md)
 30         if(b&1)
 31             ans=ans*base%md;
 32     return ans;
 33 }
 34 void dft(int *a,int len,int idx)//要求len为2的幂
 35 {
 36     int i,j,k,t1,t2;ull wn,wnk;
 37     for(i=0;i<len;++i)
 38         if(i<rev[i])
 39             swap(a[i],a[rev[i]]);
 40     for(i=1;i<len;i<<=1)
 41     {
 42         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 43         for(j=0;j<len;j+=(i<<1))
 44         {
 45             wnk=1;
 46             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 47             {
 48                 t1=a[k];t2=a[k+i]*wnk%md;
 49                 a[k]+=t2;
 50                 (a[k]>=md) && (a[k]-=md);
 51                 a[k+i]=t1-t2;
 52                 (a[k+i]<0) && (a[k+i]+=md);
 53             }
 54         }
 55     }
 56     if(idx==-1)
 57     {
 58         ull ilen=poww(len,md-2);
 59         for(i=0;i<len;++i)
 60             a[i]=a[i]*ilen%md;
 61     }
 62 }
 63 int t1[N],t2[N];
 64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2^(ceil(log2(len))+1)(需要足够长用于临时存放元素) ;要求len是2的幂
 65 {
 66     g[0]=poww(f[0],md-2);
 67     for(int i=2,j;i<(len<<1);i<<=1)
 68     {
 69         memcpy(t1,f,sizeof(int)*i);
 70         memcpy(t2,g,sizeof(int)*(i>>1));
 71         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 72         init(i);
 73         dft(t1,i,1);dft(t2,i,1);
 74         for(j=0;j<i;++j)
 75             t1[j]=ull(t1[j])*t2[j]%md;
 76         dft(t1,i,-1);
 77         for(j=0;j<(i>>1);++j)
 78             t1[j]=t1[j+(i>>1)];
 79         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 80         dft(t1,i,1);
 81         for(j=0;j<i;++j)
 82             t1[j]=ull(t1[j])*t2[j]%md;
 83         dft(t1,i,-1);
 84         for(j=i>>1;j<i;++j)
 85             delto(g[j],t1[j-(i>>1)]);
 86     }
 87 }
 88 int f[N],g[N];
 89 int n,n1;
 90 int main()
 91 {
 92     int i,t;
 93     scanf("%d",&n);n1=n;
 94     for(i=0;i<n;++i)
 95         scanf("%d",g+i);
 96     for(t=1;t<n;t<<=1);
 97     n=t;
 98     p_inv(g,f,n);
 99     for(i=0;i<n1;++i)
100         printf("%d ",f[i]);
101     return 0;
102 }
View Code

版本3:基于此题版本2,改了疑似bug

  1 #prag\
  2 ma GCC optimize(2)
  3 #include<cstdio>
  4 #include<algorithm>
  5 #include<cstring>
  6 #include<vector>
  7 #include<cmath>
  8 using namespace std;
  9 #define fi first
 10 #define se second
 11 #define mp make_pair
 12 #define pb push_back
 13 typedef long long ll;
 14 typedef unsigned long long ull;
 15 const int md=998244353;
 16 const int N=262144;
 17 #define delto(a,b) ((a)-=(b),((a)<0)&&((a)+=md))
 18 int rev[N];
 19 void init(int len)
 20 {
 21     int bit=0,i;
 22     while((1<<(bit+1))<=len)    ++bit;
 23     for(i=0;i<len;++i)
 24         rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
 25 }
 26 ull poww(ull a,ull b)
 27 {
 28     ull base=a,ans=1;
 29     for(;b;b>>=1,base=base*base%md)
 30         if(b&1)
 31             ans=ans*base%md;
 32     return ans;
 33 }
 34 void dft(int *a,int len,int idx)//要求len为2的幂
 35 {
 36     int i,j,k,t1,t2;ull wn,wnk;
 37     for(i=0;i<len;++i)
 38         if(i<rev[i])
 39             swap(a[i],a[rev[i]]);
 40     for(i=1;i<len;i<<=1)
 41     {
 42         wn=poww(idx==1?3:332748118,(md-1)/(i<<1));
 43         for(j=0;j<len;j+=(i<<1))
 44         {
 45             wnk=1;
 46             for(k=j;k<j+i;++k,wnk=wnk*wn%md)
 47             {
 48                 t1=a[k];t2=a[k+i]*wnk%md;
 49                 a[k]+=t2;
 50                 (a[k]>=md) && (a[k]-=md);
 51                 a[k+i]=t1-t2;
 52                 (a[k+i]<0) && (a[k+i]+=md);
 53             }
 54         }
 55     }
 56     if(idx==-1)
 57     {
 58         ull ilen=poww(len,md-2);
 59         for(i=0;i<len;++i)
 60             a[i]=a[i]*ilen%md;
 61     }
 62 }
 63 int t1[N],t2[N];
 64 void p_inv(int *f,int *g,int len)//g=f^(-1);f,g数组的长度不小于2len(需要足够长用于临时存放元素) ;要求len是2的幂
 65 {
 66     g[0]=poww(f[0],md-2);
 67     for(int i=2,j;i<(len<<1);i<<=1)
 68     {
 69         memcpy(t1,f,sizeof(int)*i);
 70         memcpy(t2,g,sizeof(int)*(i>>1));
 71         memset(t2+(i>>1),0,sizeof(int)*(i>>1));
 72         init(i);
 73         dft(t1,i,1);dft(t2,i,1);
 74         for(j=0;j<i;++j)
 75             t1[j]=ull(t1[j])*t2[j]%md;
 76         dft(t1,i,-1);
 77         for(j=0;j<(i>>1);++j)
 78             t1[j]=t1[j+(i>>1)];
 79         memset(t1+(i>>1),0,sizeof(int)*(i>>1));
 80         dft(t1,i,1);
 81         for(j=0;j<i;++j)
 82             t1[j]=ull(t1[j])*t2[j]%md;
 83         dft(t1,i,-1);
 84         for(j=i>>1;j<i;++j)
 85             g[j]=md-t1[j-(i>>1)];
 86     }
 87 }
 88 int f[N],g[N];
 89 int n,n1;
 90 int main()
 91 {
 92     int i,t;
 93     scanf("%d",&n);n1=n;
 94     for(i=0;i<n;++i)
 95         scanf("%d",g+i);
 96     for(t=1;t<n;t<<=1);
 97     n=t;
 98     p_inv(g,f,n);
 99     for(i=0;i<n1;++i)
100         printf("%d ",f[i]);
101     return 0;
102 }
View Code

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!