\(NTT\)

\(FFT\)使用复数单位根对\(DFT\)进行优化,\(NTT\)则使用了另外一种方式优化。
这种方式被称之为原根。
使用单位根时我们会进行大量的浮点计算,这不光让程序的运行时间大大增加,还会带来很大的进度误差。而原根则没有这样的问题。
除此之外\(NTT\)还解决了多项式乘法带模数的情况。

原根

\(a,p\)互素,\(p>1\)
对于\(a^n\equiv1(\mod p)\)最小的\(n\),我们称之为 \(a\)\(p\)的阶,记做\(δ_p(a)\)

原根

定义:

\(p\)是正整数,\(a\)是整数,若\(δ_p(a)\)等于\(\phi(a)\),则\(a\)为模\(p\)的 原根。

性质1:如果一个数字\(p\)有原根,那么它有\(\phi (\phi(p))\)个原根。
性质2:模\(p\)有原根的充要条件\(n=2,4,p^\alpha,p^{2\alpha}\)\(p\)为奇素数。
性质3若\(p\)为素数,假设\(g\)\(p\)的原根,那么\(g^i \mod p,(i<p)\) 唯一。

除此之外,\(FFT\)中单位根满足的所有性质原根也满足。
所以我们认为\(g^{\frac{p-1}{n}}\)等价于\(e^{\frac{-2\pi i}{n}}\)
\(NTT\)中,\(p\)通常取998244353。原根为3。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
#include <stack>
using namespace std;
#define ll long long

const int mod = 998244353, G = 3, Gi = 332748118;//这里的Gi是G的除法逆元
const ll N = 5e6+50;
ll n,m;
ll limit = 1;//二进制位数
ll L,R[N];//二进制位数、二进制翻转数组

ll f_pow(ll a,ll b){
    ll res = 1;
    while(b){
        if(b&1)res = res*a%mod;
        a = a*a%mod;
        b>>=1;
    }
    return res%mod;
}
ll a[N],b[N];
void ntt(ll *A,ll type){
    for(ll i = 0;i < limit;i++)if(i < R[i])swap(A[i],A[R[i]]);
    for(ll mid = 1;mid < limit;mid<<=1){
        ll wn = f_pow(G,(mod-1)/(mid<<1));//原根
        if(type == -1) wn = f_pow(wn,mod-2);
        for(ll len = mid<<1,pos = 0;pos < limit;pos+=len){
            ll w = 1;
            for(ll k = 0;k<mid;k++,w = w*wn%mod){
                //原根的操作与单位根类似
                ll x = A[pos+k],y = w*A[pos+k+mid]%mod;
                A[pos+k] = (x+y)%mod;
                A[pos+k+mid] = (x-y+mod)%mod;
            }
        }
    }
    if(type == 1)return ;
    //依然是除n,但是这里需要求逆元
    ll inv_lim = f_pow(limit,mod-2);
    for(ll i = 0;i < limit;i++) A[i] = A[i]*inv_lim%mod;
}
string p,q;
stack<ll> st;
int main() {
    cin>>p>>q;n = p.length()-1,m = q.length()-1;
    for(ll i = n;i >= 0;i--)a[i] = p[i]-'0';
    for(ll i = m;i >= 0;i--)b[i] = q[i]-'0';
    while(limit <= n+m)limit<<=1,L++;//长度
    for(ll i = 0;i < limit;i++){
        R[i] = (R[i>>1]>>1) | ((i&1)<<(L-1));
        //在原序列中i与i/2的关系是:i是i/2的左移
        //那么反转之后就需要右移,同时处理尾数
    }
    ntt(a,1);
    ntt(b,1);
    for(ll i = 0;i <= limit;i++)a[i] = a[i]*b[i];
    ntt(a,-1);//逆变换

    for(ll i = n+m;i > 0;i--){
        ll tmp = a[i];
        //cout<<a[i]<<endl;
        st.push(tmp%10);
        a[i-1] += tmp/10;
    }
    cout<<a[0];
    while(!st.empty()){
        cout<<st.top();
        st.pop();
    }
    return 0;
}