多项式算法1:FFT(快速傅里叶变换)
- 前言
- 前置技能
- 正文
前言
算法简介
快速傅里叶变换 (fast Fourier transform), 即利用计算机计算离散傅里叶变换(DFT)的高效、快速计算方法的统称,简称FFT。快速傅里叶变换是1965年由J.W.库利和T.W.图基提出的。采用这种算法能使计算机计算离散傅里叶变换所需要的乘法次数大为减少,特别是被变换的抽样点数N越多,FFT算法计算量的节省就越显著。——百度百科
其实是一个用于快速求解多项式之积的算法。
多项式:形如的式子,展开来就是
如果我们用,表示两个不同的多项式,现在要求解两个数的乘积,朴素求法是直接相乘,时间复杂度为(过程请参考高精乘法),当时会炸。
不过,当我们换一种方法来表示多项式后,结果会不会好一点呢。考虑用离散的点来表示这个函数,根据高斯消元,不难看出,用至少n个不同的点来表示一个n元函数才可以唯一确定这个n元函数。现假设我用点来表示n元多项式,用点来表示n元多项式。
不难看出(最高为次)可用点表示。
算法时间复杂度为,已经能达到我们的要求了,那是不是说明我们成功完成了任务了呢。不,那还早呢,我们要的是系数表示的多项式,而不是点值表达式,如果按照的方法来取的值,用秦九韶算法将系数多项式转为点值,用高斯消元将点值转为系数多项式,那么时间复杂度为甚至更高。
那么,我们有没有一个好的算法来优化这一过程呢,这就是我们今天要讨论的问题了,这就是传说中的FFT(快速傅里叶变换)。
前置技能
DFT与IDFT
首先,我们要提一下什么是离散傅里叶变换(DFT),离散傅里叶变换就是将一个系数表达式变为一个点值表达式;而将一个点值表达式变化为一个系数表达式称为离散傅里叶逆变换(IDFT)。FFT就是以巨大常数为代价快速进行DFT和IDFT的算法。
复数
这个算法首先要在选未知数上做文章,如果未知数的乘幂一直都是1,那计算就方便许多。但是这样的数我们只能找到1和-1,如何找到更多满足条件的数呢?
这就有赖于数系的扩充了,在初中阶段,数学老师教育我们,这个数是不存在的。但在虚数中,数学家用符号来表示。复数就是形如的所有数,当时复数就是我们熟知的实数。下面要用到复数,如果还有不懂可以自行搜索,我这里就不再深入探讨复数了。
矩阵
在往下看之前最好要了解矩阵的基本概念,以及单位矩阵,逆矩阵的定义,不懂的请自行上网问度娘。
欧拉公式
正文
对复数了解一点的人都知道,复数可以表示成平面直角坐标系上的一个点,如下图:
该点可表示为。
在上文中,我们试着用数字1到n带入多项式来表示一个n次多项式,这样计算非常费力,我们需要谨慎选好带入的数。根据经验,乘幂为1或-1时非常好算,考虑1和-1。当n较大时,实数已经不能满足我们的需求了,此时,代入一些复数或许可以满足我们的要求。
以下是用结构体实现复数的代码:
const double pi = acos(-1.0);
struct virt{
double r , i;
virt( double r = 0.0 , double i = 0.0 ) {
this->r = r;
this->i = i;
}
virt operator + ( const virt &x ) {
return virt( r + x.r , i + x.i );
}
virt operator - ( const virt &x ) {
return virt( r - x.r , i - x.i );
}
virt operator * ( const virt &x ) {
return virt( r * x.r - i * x.i , i * x.r + r * x.i );
}
//复数除法用不到,就不写了
};
单位复根
鉴于上图用坐标系表示复数,我们是否也可以找到一类复数,它们的乘幂都可能为1或-1?
答案是肯定的,如下图:
在该单位圆上的所有复数都符合条件。
复数的乘法在复平面中表现为辐角相加,模长相乘。利用这一点,我们可以轻易地处理出所有复数的次幂。
复数满足称作是次单位根,下图包含了所有的4次单位根(图中圆的半径是1)。
同样的,下图是所有的8次单位根。
由此我们不难发现单位根有如下性质:
FFT的具体过程
FFT就是将系数表示法转化成点值表示法相乘,再由点值表示法转化为系数表示法的过程,第一个过程叫做离散傅里叶变换(DFT),又称求值,第二个过程叫做离散傅里叶逆变换(IDFT),又称插值。
DFT详细过程
想要求出一个多项式的点值表示发,需要选出个数分别带入到多项式里面。带入一个数复杂度是,那么总复杂度是的,这是无法接受的,但我们可以给多项式代入来利用次单位根的性质来加速我们的运算。
设为偶次项的和,为奇次项的和,则:
因为:所以有:
也就是说,只要有了和的点值表示,就能在的时间算出的点值表示,对于当前层确定的位置,就可以用下一层的两个值更新当前的值,我们称这个操作为“蝴蝶变换”。
因为这个过程一定要求每层都可以分成两大小相等的部分,所以多项式最高次项一定是,如果不够,在后面补零即可。
由此我们不难得到递归的写法:
void DFT( virt F[] , int len ) {
if( len == 1 ) {
return;
}
virt *a0 = new virt[len / 2];
virt *a1 = new virt[len / 2];
for ( int i = 0 ; i < len ; i += 2 ) {
a0[i / 2] = a[i];
a1[i / 2] = a[i + 1];
}
DFT( a0 , len / 2 );
DFT( a1 , len / 2 );
virt wn( cos( 2 * pi / len ) , sin( 2 * pi / len ) );
virt w( 1 , 0 );
for ( int i = 0 ; i < ( len / 2 ) ; ++i ) {
F[i] = a0[i] + w * a1[i];
F[i + len / 2] = a0[i] - w * a1[i];
w = w * wn;
}
}
分析一下时间复杂度,的规模为,而其奇偶次项的规模都为,而合并时间为,可得,根据主定理,该问题时间复杂度为。
代码写出来了,但这个递归版的DFT常数巨大,要是IDFT再来一次根本无法接受,于是我们考虑迭代的解法。
优化的DFT
首先我们先拿为例子:
不难发现,最终结果为(二进制):
每组数刚好互为位二进制表示下的逆序字符串。显然可得 ,当,最后每组数也互为位二进制表示下的逆序字符串。
利用这个性质,处理前先交换每对这类数,可以直接从下向上用迭代模拟递归的过程。
接下来,我们要找到一个方法快速地求出每个数在位二进制表示下对应的数。
先贴上代码:
void pre( int bit ) {
for ( int i = 0 ; i < ( 1 << bit ) ; ++i ) {
rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit - 1));
//cout << rev[i] << endl;
}
return;
}
只要这一行位运算,就可以在的时间内求出每个数在位二进制表示下颠倒过来后对应的数。
证明:假设此时处理的数二进制表示为,由于单调递增,算到时已经算过了,二进制表示为,
为,
为,
显然为
最终可得为,完全符合要求。
讲到这,迭代版优化的DFT已经呼之欲出了:
void DFT( virt F[] , int len ) {
virt tem;
for ( int i = 0 ; i < len ; ++i ) {//把递归的底层交换好
if ( i < rev[i] ) {
tem = F[i];
F[i] = F[rev[i]];
F[rev[i]] = tem;
}
}
for ( int i = 2 ; i <= len ; i <<= 1 ) {//枚举步长,从递归的下面往上走
virt wn( cos( 2 * pi / i ) , sin( 2 * pi / i ) );//2pi为一个圆周
for ( int j = 0 ; j <= len - 1 ; j += i ) {//走一遍步长
virt w( 1 , 0 );
for ( int k = j ; k < j + i / 2 ; ++k ) {//枚举每块区间内的每一个元素
virt u = F[k];
virt v = w * F[k + i / 2];
F[k] = u + v;
F[k + i / 2] = u - v;
w = w * wn;
}
}
}
return;
}
IDFT
要想由点值表示法转化为系数表示法,暴力求也是。直观上似乎也很难想到一个更好的方法来进行插值求解的过程。不过,我们先看看DFT转换用矩阵乘法表示的过程,看看能不能发现什么头绪:
从这式子不难看出,只要找出
的逆矩阵,就可以用跟DFT一样的方法来实现IDFT了。
可以证出只要把每个换成,即虚部取负,然后再除以即可得到逆矩阵。
证明如下:
如上矩阵,
逆变换为:
式子
展开的求和式得
由等比数列的求和公式得:
当,可得:
当,可得:
因而:
该矩阵的逆矩阵得证。
贴上DFT与IDFT结合起来的代码:
void FFT( virt F[] , int len , int on ) {
virt tem;
for ( int i = 0 ; i < len ; ++i ) {//把递归的底层交换好
if ( i < rev[i] ) {
tem = F[i];
F[i] = F[rev[i]];
F[rev[i]] = tem;
}
}
for ( int i = 2 ; i <= len ; i <<= 1 ) {//枚举步长,从递归的下面往上走
virt wn( cos( 2 * pi / i ) , sin( on * 2 * pi / i ) );//2pi为一个圆周
//一般习惯在函数外面除以len
for ( int j = 0 ; j <= len - 1 ; j += i ) {//走一遍步长
virt w( 1 , 0 );
for ( int k = j ; k < j + i / 2 ; ++k ) {//枚举每块区间内的每一个元素
virt u = F[k];
virt v = w * F[k + i / 2];
F[k] = u + v;
F[k + i / 2] = u - v;
w = w * wn;
}
}
}
return;
}
至此我们终于学会了FFT \ ^ o ^ /。
贴上一道模板题的完整代码:
FFT的模板题
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#define N 4097153
using namespace std;
const double pi = acos(-1.0);
struct virt{
double r , i;
virt( double r = 0.0 , double i = 0.0 ) {
this->r = r;
this->i = i;
}
virt operator + ( const virt &x ) {
return virt( r + x.r , i + x.i );
}
virt operator - ( const virt &x ) {
return virt( r - x.r , i - x.i );
}
virt operator * ( const virt &x ) {
return virt( r * x.r - i * x.i , i * x.r + r * x.i );
}
};
virt a[N] , b[N];
int rev[N];
int n , m , l;
void pre( int bit ) {
for ( int i = 0 ; i < ( 1 << bit ) ; ++i ) {
rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit - 1));
}
return;
}
void FFT( virt F[] , int len , int on ) {
virt tem;
for ( int i = 0 ; i < len ; ++i ) {//把递归的底层交换好
if ( i < rev[i] ) {
tem = F[i];
F[i] = F[rev[i]];
F[rev[i]] = tem;
}
}
for ( int i = 2 ; i <= len ; i <<= 1 ) {//枚举步长,从递归的下面往上走
virt wn( cos( 2 * pi / i ) , sin( on * 2 * pi / i ) );//2pi为一个圆周
for ( int j = 0 ; j <= len - 1 ; j += i ) {//走一遍步长
virt w( 1 , 0 );
for ( int k = j ; k < j + i / 2 ; ++k ) {//枚举每块区间内的每一个元素
virt u = F[k];
virt v = w * F[k + i / 2];
F[k] = u + v;
F[k + i / 2] = u - v;
w = w * wn;
}
}
}
return;
}
int main () {
int len = 0;
scanf("%d",&n);
scanf("%d",&m);
for ( int i = 0 ; i <= n ; ++i ) {
scanf("%lf",&a[i].r);
}
for ( int i = 0 ; i <= m ; ++i ) {
scanf("%lf",&b[i].r);
}
len = n + m;
int tim = 1;
l = 0;
while( tim <= len ) {
tim <<= 1;
l++;
}
len = tim;
pre( l );
FFT( a , len , 1 );
FFT( b , len , 1 );
for ( int i = 0 ; i <= len - 1 ; ++i ) {
a[i] = a[i] * b[i];
}
FFT( a , len , -1 );
for ( int i = 0 ; i <= n + m ; ++i ) {
printf("%d ",(int)(a[i].r / len + 0.5));//四舍五入
}
return 0;
}
关于精度
由于FFT进行了大量的浮点数运算,我们要预处理所有要用到的单位复根及其幂(用三角函数式计算),这样可以保证精度。目前越来越多的出题人开始用这一点来卡人,只要不预处理就会爆精度。
根据之前的定义,不难得出(为次单位根):
这样虽然比上面那个方法慢,但是精度误差更小,可以满足精度的要求,防止出题人卡精度
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#define N 4097153
using namespace std;
const double pi = acos(-1.0);
struct virt{
double r , i;
virt( double r = 0.0 , double i = 0.0 ) {
this->r = r;
this->i = i;
}
virt operator + ( const virt &x ) {
return virt( r + x.r , i + x.i );
}
virt operator - ( const virt &x ) {
return virt( r - x.r , i - x.i );
}
virt operator * ( const virt &x ) {
return virt( r * x.r - i * x.i , i * x.r + r * x.i );
}
};
virt a[N] , b[N] , w0[N] , w1[N];
int rev[N];
int n , m , l;
void pre( int bit ) {
for ( int i = 0 ; i < ( 1 << bit ) ; ++i ) {
rev[i] = (rev[i>>1]>>1)|((i&1)<<(bit - 1));
}
return;
}
void FFT( virt F[] , int len , int on ) {
virt tem;
for ( int i = 0 ; i < len ; ++i ) {//把递归的底层交换好
if ( i < rev[i] ) {
tem = F[i];
F[i] = F[rev[i]];
F[rev[i]] = tem;
}
}
for ( int i = 2 ; i <= len ; i <<= 1 ) {//枚举步长,从递归的下面往上走
for ( int j = 0 ; j <= len - 1 ; j += i ) {//走一遍步长
for ( int k = j ; k < j + i / 2 ; ++k ) {//枚举每块区间内的每一个元素
virt w;
if( on == 1 ) {
w = w0[i / 2 + k - j];
} else {
w = w1[i / 2 + k - j];
}
virt u = F[k];
virt v = w * F[k + i / 2];
F[k] = u + v;
F[k + i / 2] = u - v;
}
}
}
return;
}
int main () {
int len = 0;
scanf("%d",&n);
scanf("%d",&m);
for ( int i = 0 ; i <= n ; ++i ) {
scanf("%lf",&a[i].r);
}
for ( int i = 0 ; i <= m ; ++i ) {
scanf("%lf",&b[i].r);
}
len = n + m;
int tim = 1;
l = 0;
while( tim <= len ) {
tim <<= 1;
l++;
}
len = tim;
for ( int j = 1 ; j < len ; j <<= 1 ) {
for ( int i = j ; i <= ( j << 1 ) - 1 ; ++i ) {
w0[i] = virt( cos( pi / j * ( i - j ) ) , sin( pi / j * ( i - j ) ) );
w1[i] = virt( cos( pi / j * ( i - j ) ) , sin( -1 * pi / j * ( i - j ) ) );
}
}
pre( l );
FFT( a , len , 1 );
FFT( b , len , 1 );
for ( int i = 0 ; i <= len - 1 ; ++i ) {
a[i] = a[i] * b[i];
}
FFT( a , len , -1 );
for ( int i = 0 ; i <= n + m ; ++i ) {
printf("%d ",(int)(a[i].r / len + 0.5));//四舍五入
}
return 0;
}
2022.5.17修改了一下,可以过,但运行效率上确实更慢。