多项式求逆
多项式求逆指的是给定一个多项式F(x),求出一个多项式G(x)满足
F(x)∗G(x)≡1(modxn)
它是怎么做的?
我们称一个多项式的“度”为其最高次项系数+1
首先,我们知道当n=1的时候,显然G(x)即为F(x)的常数项之逆元
我们将原式写成模x⌈2n⌉意义下的形式:
F(x)∗G(x)≡1(modx⌈2n⌉)
假设我们已经求出B(x)满足
F(x)∗B(x)≡1(modx⌈2n⌉)
将两个式子相减
G(x)−B(x)≡0(modx⌈2n⌉)
平方一下
G2(x)−2G(x)B(x)+B2(x)≡0(modxn)
两边乘上F(x)
G(x)−2B(x)+F(x)B2(x)≡0(modxn)
(这里由于F(x)∗G(x)≡1(modxn),消去了一些部分)
移项整理得
G(x)≡(2−F(x)B(x))B(x)(modxn)
多项式乘法可以用FFT/NTT加速
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
| #include <cstdio> #include <cstdlib> #include <cmath> #include <algorithm> #include <cstring> #include <iostream> #define inv(x) (fastpow((x),mod-2)) using namespace std; typedef long long ll;
template <typename T>void read(T &t) { t=0;int f=0;char c=getchar(); while(!isdigit(c)){f|=c=='-';c=getchar();} while(isdigit(c)){t=t*10+c-'0';c=getchar();} if(f)t=-t; }
const ll mod=998244353,gg=3,ig=332748118; const int maxn=100000+5; int n; ll a[maxn<<2],b[maxn<<2];
ll fastpow(ll a,ll b) { ll re=1,base=a; while(b) { if(b&1) re=re*base%mod; base=base*base%mod; b>>=1; } return re; }
int len; int r[maxn<<2]; void NTT(ll *f,int type) { for(register int i=0;i<len;++i) if(i<r[i]) swap(f[i],f[r[i]]); for(register int p=2;p<=len;p<<=1) { int length=p>>1; ll unr=fastpow(type?gg:ig,(mod-1)/p); for(register int l=0;l<len;l+=p) { ll w=1; for(register int i=l;i<l+length;++i,w=w*unr%mod) { ll tt=f[i+length]*w%mod; f[i+length]=(f[i]-tt+mod)%mod; f[i]=(f[i]+tt)%mod; } } } if(!type) { ll ilen=inv(len); for(register int i=0;i<len;++i) f[i]=f[i]*ilen%mod; } }
ll c[maxn<<2]; void getinv(int deg,ll *a,ll *b) { if(deg==1) { b[0]=inv(a[0]); return; } getinv((deg+1)>>1,a,b); for(len=1;len<=(deg<<1);len<<=1); for(register int i=0;i<len;++i) { r[i]=(r[i>>1]>>1)|((i&1)?len>>1:0); c[i]=(i<deg?a[i]:0); } NTT(c,1),NTT(b,1); for(register int i=0;i<len;++i) b[i]=(2ll-c[i]*b[i]%mod+mod)%mod*b[i]%mod; NTT(b,0); fill(b+deg,b+len,0);//重要,因为是在模 x^deg 意义下 }
int main() { read(n); for(register int i=0;i<n;++i)read(a[i]); getinv(n,a,b); for(register int i=0;i<n;++i) printf("%lld ",b[i]); return 0; }
|