多项式指数函数

岁酱吖の 提交于 2020-01-02 08:17:18

给定\(n-1\)次多项式\(A(x)\),求一个\(mod\, x^n\)下的多项式\(B(x)\),满足\(B(x)≡e^{A(x)}\),系数对\(998244353\)取模

一大堆前置姿势:

微积分
多项式对数函数
多项式牛顿迭代

计算\(F(x)≡e^{A(x)}(mod\, x^n)\)

两边同时取对数得

\(lnF(x)-A(x)≡0 (mod\, x^n)\)

\(G(F(x))=lnF(x)-A(x)(mod\, x^n)\)

套一下牛顿迭代公式

\(F(x)=F_0(x)-\frac{G(F(x))}{G'(F_0(x))}(mod\, x^n)\)

\(A(x)\) 是给定的常数项,那么\(G'(x)=G(x)^{-1}\)

可以整理出来

\(F(x)=F_0(x)(1-lnF_0(x)+A(x))(mod\, x^n)\)

然后递归求解,注意边界为\(F(0)=1\)

#include<bits/stdc++.h>
using namespace std;
namespace red{
#define int long long
#define eps (1e-8)
    inline int read()
    {
        int x=0;char ch,f=1;
        for(ch=getchar();(ch<'0'||ch>'9')&&ch!='-';ch=getchar());
        if(ch=='-') f=0,ch=getchar();
        while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
        return f?x:-x;
    }
    const int N=5e5+10,p=998244353,g=3;
    int n;
    int a[N],b[N],c[N],f[N],pos[N];
    int A[N],B[N];
    inline int fast(int x,int k)
    {
        int ret=1;
        while(k)
        {
            if(k&1) ret=ret*x%p;
            x=x*x%p;
            k>>=1;
        }
        return ret;
    }
    inline void ntt(int limit,int *a,int inv)
    {
        for(int i=0;i<limit;++i)
            if(i<pos[i]) swap(a[i],a[pos[i]]);
        for(int mid=1;mid<limit;mid<<=1)
        {
            int Wn=fast(g,(p-1)/(mid<<1));
            for(int r=mid<<1,j=0;j<limit;j+=r)
            {
                int w=1;
                for(int k=0;k<mid;++k,w=w*Wn%p)
                {
                    int x=a[j+k],y=w*a[j+k+mid]%p;
                    a[j+k]=x+y;
                    if(a[j+k]>=p) a[j+k]-=p;
                    a[j+k+mid]=x-y;
                    if(a[j+k+mid]<0) a[j+k+mid]+=p;
                }
            }
        }
        if(inv) return;
        inv=fast(limit,p-2);reverse(a+1,a+limit);
        for(int i=0;i<limit;++i) a[i]=a[i]*inv%p;
    }
    inline void deriva(int *a,int *b,int n)
    {
        for(int i=1;i<n;++i) b[i-1]=a[i]*i%p;
        b[n-1]=0;
    }
    inline void integral(int *a,int n)
    {
        for(int i=n-1;i;--i) a[i]=a[i-1]*fast(i,p-2)%p;
        a[0]=0;
    }
    inline void poly_inv(int pw,int *a,int *b)
    {
        if(pw==1) {b[0]=fast(a[0],p-2);return;}
        poly_inv((pw+1)>>1,a,b);
        int len=0,limit=1;
        while(limit<(pw<<1)) limit<<=1,++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        for(int i=0;i<pw;++i) c[i]=a[i];
        for(int i=pw;i<limit;++i) c[i]=0;
        ntt(limit,c,1);ntt(limit,b,1);
        for(int i=0;i<limit;++i) b[i]=((2-c[i]*b[i]%p)+p)%p*b[i]%p;
        ntt(limit,b,0);
        for(int i=pw;i<limit;++i) b[i]=0;
    }
    inline void ln(int *a,int *b,int n)
    {
        deriva(a,A,n),poly_inv(n,a,B);
        int len=0,limit=1;
        while(limit<(n<<1)) limit<<=1,++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        ntt(limit,A,1);ntt(limit,B,1);
        for(int i=0;i<limit;++i) b[i]=A[i]*B[i]%p;
        ntt(limit,b,0);
        integral(b,limit);
        for(int i=0;i<limit;++i) A[i]=B[i]=0;
    }
    inline void exp(int pw,int *a,int *b)
    {
        if(pw==1) {b[0]=1;return;}
        exp((pw+1)>>1,a,b);ln(b,f,pw);
        int len=0,limit=1;
        while(limit<(pw<<1)) limit<<=1,++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        f[0]=(a[0]+1-f[0])%p;
        for(int i=1;i<pw;++i) f[i]=(a[i]-f[i]+p)%p;
        ntt(limit,f,1);ntt(limit,b,1);
        for(int i=0;i<limit;++i) b[i]=b[i]*f[i]%p;
        ntt(limit,b,0);
        for(int i=pw;i<limit;++i) b[i]=f[i]=0;
    }
    inline void main()
    {
        n=read();
        for(int i=0;i<n;++i) a[i]=read();
        exp(n,a,b);
        for(int i=0;i<n;++i) printf("%lld ",b[i]);
    }
}
signed main()
{
    red::main();
    return 0;
}
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!