Yan`s Notepad

--- My Notepad......Articles, tools and etc.
清晰易懂的FFT实现
在开始之前,希望读者明白复数基本的知识(加减乘除,模长),然后还有欧拉恒等式,当然,基本的三角函数也得明白,更进一步,对FFT有一定的认识,至少要明白它是干什么的。都不知道?那还看这个干什么啊,速速右转。关于欧拉恒等式,3blue1brown有一个很有趣的视频在4分钟不到的时间内讲了它,有令人恍然大悟的感觉。如果可能,非常建议去看看。当然,这篇文章不会讲傅里叶变换的原理。看标题也应该明白:讲的是FFT的那个算法的实现。
路线图:推导数学式子 -> 递归版本的FFT实现 -> 迭代版的FFT实现 -> 对迭代版本FFT进行优化 -> 蝶形图 -> 最终的FFT代码。如果可能,不要跳过任何部分,请耐心看完。
下面切入正题——快速傅里叶变换(Fast Fourier Transform),也就是通俗的FFT。FFT的版本有很多种(除了下文的这个基-2(Radix-2)以外,还有基-3(Radix-3),基-4(Radix-4),分裂基,矢量基,威诺格拉德短DFT,稀疏矩阵分解......等等的实现,总而言之不局限于下面的一种)。具体而言,基-n的算法都是差不多的,同时也是最简单的。因此,将用基-2的FFT(准确而言,它叫做DIT FFT。下面文章中简称的FFT都是这个)为例子,说明它的实现过程。DFT的分解过程的数学过程并不难懂,它只是在拆分求和项,不过式子长得吓人,不要被这些看着很大的式子吓跑了。因此,我们依然有必要来分解一下它。全文用i表示虚数单位,补充一下旋转因子Wnk的定义:
small
基-2的FFT用于计算一个N点的序列,N是2的整数幂。它将这个序列分解为两个N/2点的小序列,一个包含了下标为偶数的序列元素,一个包含了下标为奇数的序列元素。它计算这两个小序列的DFT得到原始的DFT。简单来说,我们有一个DFT的式子:
small
自然的,也有一个IDFT式子(下图)。它们两个很像,除了前面乘了个$\frac{1}{N}$外,差异便是旋转因子Wnk符号改了。这意味着我们所作出的FFT需要变成IFFT也只需要两步:旋转因子更改符号,然后把结果乘以$\frac{1}{N}$。因此,我们只需要讨论FFT即可,而IFFT只需要在FFT上做一点点修改,其余都是一样的。
small
下面, 我们把求和项拆开,把它表示为一个奇数的求和项和一个偶数的求和项,得到:
small
最后的那个式子就是拆分完成后的奇数项(右)和偶数项(左),可以看到,它们两个非常相似。而我们把nk提出,又可以得到:
small
注意,按照旋转因子的定义,括号内的项可以等于:
small
因此,我们可以得到最初的N点DFT拆分后的式子:
small
这下,FFT的过程就清晰了:它将DFT拆分为两个小的DFT,这两个小的DFT自然可以继续拆分为更小的DFT...如此往下,拆分到只有2个点的DFT,进行运算,然后倒回去。进而得到结果。而基-n的FFT相似地方就在于此,我们可以把它拆分为3项,也将发现类似的规律——这是基-3的FFT;拆分为4项,也会有类似的规律——这是基-4的FFT......注意,因为它要一直拆分到两个点的DFT,这意味着基2的FFT的点数是2的整数幂。同理,基-3的FFT点数应当是3的整数幂。为了方便表示,用G和F表示了这两个N/2点的小DFT的结果,G表示的变换是对应于偶数项的,F表示的变换是对应于奇数项的:
small
考虑到旋转因子在复平面上所对应的那个单位圆,很显然,对于N点的DFT,它的周期为N:X(k+N)==X(k),对应的,N/2点的DFT周期为N/2。不过,我们应该对N点的DFT的k+N/2项感兴趣——它刚好是关于原点对称的,直觉说明,Wn(k+n/2)将会与Wn(k)反向:
small
这样一切都准备就绪。我们来用8个点的数据,来体会一下它的计算过程。
我们从8个点出发,第一次拆分为两个4点的,然后变成4个两点的,最后变成8个1点的。对于1点的数据做DFT没什么意思,直接返回的是它本身。而2点的、4点的、8点的等,都可以由上一组数据得到答案。把图画出来,它看起来像这样(图中省略了运算,因为合并过程有点复杂):
···图片正在加载···
images/radixnFFT/fftproc.png
FFT的拆分与运算
怎么样, 是不是像极了归并排序呢?它们都可以用分治法实现啊
拆分就不多说了。每次合并时, 结果都按照如下的规则进行,其中的N是C的DFT点数,而A, B则是拆分下来的小DFT,A是它的偶数部分,B是它的奇数部分
small
同样的,我们先把每个序列持续拆分,直到只有一个数为止。因为数据是按照奇偶项拆分的,所以我们需要准备一个数组,用来顺序保存这些拆分后的小序列。比如在拆分的第一步中,这个数组就会存放$x_0, x_2, x_4, x_6, x_1, x_3, x_5, x_7$。拆分的第二步里面会存放x(0), x(4), x(2), x(6), x(1), x(5), x(3), x(7)等。然后按照顺序把它们合并,即可得到结果:
void fft_rec(complex<double> *dat, int n)
{
    complex<double>  tmp[FFT_NUM];                      // 存放拆分时候的数据
    int              mid, i, k;
    if (n == 1)
        return;
    mid = n / 2;
    for (i = 0; i < mid; i++)                           // 按照先偶数项后奇数项的顺序, 把ffdat放入tmp数组
    {
        tmp[i] = dat[i * 2];                            // 偶数项目
        tmp[i + mid] = dat[i * 2 + 1];                  // 奇数项目
    }
    for (i = 0; i < n; i++)                             // 把这些数据拷贝回原始数组
    {                                                   // 此时, dat存放着拆分的数据
        dat[i] = tmp[i];
    }
    fft_rec(dat, mid);                                  // 递归拆分它们
    fft_rec(dat + mid, mid);                            // 不合并的话, 就可以得到拆分结果了
    for (k = 0; k < mid; k++)                           // 下面开始合并数据, 把上次拆分的两个小DFT合并成大的DFT
    {                                                   // 而dat则存放着拆分后的数据
        complex<double>  wnk;
        WnkCalc(&wnk, n, k);                            // 计算旋转因子
        tmp[k] = dat[k] + wnk * dat[k + mid];           // 0到N/2的数据, dat[k]是偶数部分, dat[k+mid]是奇数部分
        tmp[k + mid] = dat[k] - wnk * dat[k + mid];     // N/2到N的数据
    }
    for (i = 0; i < n; i++)                             // 把结果送回dat
    {
        dat[i] = tmp[i];
    }
}
其中用到了C++的std命名空间中的复数模板库。复数运算啥的,可以自己实现,本身并不复杂。
在递归版的FFT中,应该是可以很清晰的看到整个合并规则了。很显然,前面的递归部分是一直在拆分数组,让它得到特定是顺序;然后在从底往上的去合并数组。换种说法,拆分的次数是已知的,它就是log(2,N),若是我们现在有一个拆分后的序列,然后对它再一步步向上合并,自然也可以得到正确的答案。而拆分后的序列顺序也有一定规律,虽然难以发现:
最大的数是7,它的二进制是111,也就是3位的。现在考虑几个数:
最开始序列的第1个位置是x(1),而拆分后(过程图的中间只剩下1个元素的时候,就是拆分好了),变成了x(4)。1的二进制是001,4的二进制是100。很明显:它们的二进制位是相互颠倒的。
再随意找一个,比如第3个位置是x(3),拆分后变成了x(6),3的二进制是011,6的二进制是110。它们的二进制位依旧是相互颠倒的。
这就好办了——我们只需要按照二进制位颠倒的这种顺序,去重新排列这个输入即可得到拆分后的序列顺序:
// 计算n的整数幂
int ilog2(int n)
{
    int  r = 0;
    while (n > 0)
    {
        n >>= 1;                                        // 2的整数幂都是0000 1 00000这样的二进制形式
        r++;
    }
    return r - 1;                                       // 2^0是1而不是0, 因此需要减去1位
}
// 反比特顺序
void bitRev(complex<double> *dat, int n)
{
    int  i, j, r, index;
    index = ilog2(n);                                   // 得到它是2的几次方
    for (i = 0; i < n; i++)
    {
        r = 0;
        for (j = 0; j < index; j++)
        {
            int value = 1 << j;
            if ((i & value) == value)                   // 发现第j+1位是1(因为j是从0开始的...)
                r |= 1 << (index - j - 1);              // 对应输出的颠倒过来的那个位设置为1
        }
        if (i < r)                                      // i < r才处理, 不然会重复交换, 等于白做
            swap(dat[i], dat[r]);
    }
}
上面的这种事情叫做反比特顺序。
反比特顺序之后,我们需要处理的DFT数目也可以从那个过程图里面发现:刚开始被拆分为2^0也就是1那么大,需要得到2点的DFT,然后需要得到4点的......因而,处理的DFT点数是2^1, 2^2...这样的规律,直到最后需要得到2^3即8点的DFT大小。然后我们再考虑当前合并的DFT数目。同样的,按照那个过程图,最开始的时候是单个序列,需要合并4次,然后变成了4个2点的DFT,需要合并2次变成2个4点的DFT,最后需要合并1次变成1个8点的DFT,就是结果。很显然,合并次数依然是2^2, 2^1....这样的规律。
如此说来,迭代版本的DFT便很好实现:最外层循环处理需要进行几层的运算,它的循环次数是log(2, N),下一层处理每一个小的DFT,要求把这mergeN个DFT合并。最内层循环合并这些大小为curN/2的DFT,让它变成大小为curN的DFT。其中,数据依然是按照之前递归版本的那样,先偶数部分后奇数部分的排列:
// 迭代版的FFT, 数目当然是FFT_NUM
void fft_iter(complex<double> *dat)
{
    int              lay, index, curN, mergeN, j, k;
    int              beg, midN;
    complex<double>  tmp[FFT_NUM], wnk;
    bitRev(dat, FFT_NUM);                               // 对数据进行反比特操作
    index = ilog2(FFT_NUM);
    for (lay = 1; lay <= index; lay++)                  // 我们从最底层开始计算DFT, 一共存在着log(2,N)那么多层
    {
        curN = 1 << lay;                                // 当前合并的DFT大小
        midN = curN >> 1;                               // 被合并的DFT大小, 它是合并结果大小curN的一半
        mergeN = 1 << (index - lay);                    // 需要合并的次数
        for (j = 0; j < mergeN; j++)                    // 合并每一个小的DFT
        {
            beg = j * curN;                             // 合并的这些序列的开始位置, 针对于dat而言的
            for (k = 0; k < midN; k++)                  // 处理这些小DFT
            {
                WnkCalc(&wnk, curN, k);                 // 计算旋转因子
                tmp[k] = dat[beg + k] + wnk * dat[beg + k + midN];
                tmp[k + midN] = dat[beg + k] - wnk * dat[beg + k + midN];
            }
            for (k = 0; k < curN; k++)                  // 把结果拷贝回dat
            {
                dat[beg + k] = tmp[k];
            }
        }
    }
}
下面来对这个迭代版本的FFT进行优化——毕竟直觉也能感受到,它还是比较慢的。为了方便处理,我们先不考虑Wnk计算时的开销。首先,我们考虑消除那个数组tmp,隐约的直觉告诉我们,这个数组是多余的——因为dat里面是有原始数据的,而这个数组的内容每次都扔到了dat中。如考虑tmp的作用,它是为了存放两个小DFT的合并结果而存在:
tmp[k] = dat[beg + k] + wnk * dat[beg + k + midN];
tmp[k + midN] = dat[beg + k] - wnk * dat[beg + k + midN];
如果左边的tmp换成对应的dat数据的话,那么结果很明显——偶数部分的DFT倒是合并了,但是奇数部分的DFT合并就出毛病,因为它修改了dat[beg + k]。于是,我们把一些部分移出来:
complex<double>  t;
t = wnk * dat[beg + k + midN];
dat[beg + k + midN] = dat[beg + k] - t;
dat[beg + k] = dat[beg + k] + t;
先去修改dat[beg + k + midN],这样就不会影响到奇数部分合并, 而偶数部分不使用它。因此,这个影响消除,我们可以放心的移走tmp数组和下面的那个拷贝数据用的循环。
这个操作有着令人不解的名字——蝴蝶操作,因为我怎么看都不觉得它所对应的那个蝶形图长得像蝴蝶。作为一个会编写FFT程序的人,你应该了解它,好方便与朋友吹牛,当然,不看也罢,毕竟已经知道了这个蝶形图所带来的优化的样子。我们先回到那个合并规则的图:
small
下面,我们使用蝶形图表示出这个操作:
small
它像蝴蝶吗...
或许还是不解。红色的数字是我标的。事实上,蝶形图上的这些箭头圆圈表示了数据的流向。在标有1的位置,送过来的箭头分裂成两个箭头——这意味着数据在这里分为两组,这两组互不干扰;在标有2的位置,两组箭头发生了合并,然后送出,这表明这两组数据将在这里求和,结果到送出的那条线上;在标有3的位置,一些值出现在了线的附加,这表明这条线上的数据会与这个值相乘,而结果会替换掉这条线上原来的数据。因此,下面的这条蓝色区域,就表示了这样的式子:
small
可以发现,B(k)都是乘了旋转因子的,我们可以把上面的图简化为这样:
small
这个就是我们所作的那个优化的蝶形图表示。仔细想想是不是呢?它先和B(k)做了一次乘法,然后把值分配到两端(注意,分开的线已经表明了这两个值不会相互影响),也就是代码中的t。而后,在C(k)边上的那个圆,它与A(k)求和,这个就是优化的那段代码的第4行;而下面,它乘以-1而后与A(k)求和,这是就是优化的那段代码的第3行。现在,来尝试画出整个FFT的蝶形图吧!在此之前,一定要理解了上面的图的意思。
数据是送入迭代版本FFT的那个数据,也就是之前的过程图开始合并的那里。第一次合并,我们需要对数据流进行4次小的2点DFT操作,把这四次小的DFT操作画出来,按照小的蝶形图,我们给对应的数据乘上小的旋转因子,在相应位置乘上-1,之后也是一样的:
···图片正在加载···
images/radixnFFT/butterfly_fft0.png
第一层的DFT
接着,我们在第一层后面接上第二层,第二层的DFT数目是4,进行两次DFT,因此得到:
···图片正在加载···
images/radixnFFT/butterfly_fft1.png
第二层的DFT
最后,加上第三层,它的DFT数目是8,进行1次:
···图片正在加载···
images/radixnFFT/butterfly_fft2.png
8点FFT的蝶形图
这个,就是在各个大佬的文章里面出现的蝶形图。
经过我们之前的优化,FFT的代码已经快了很多。但是我们知道,本身而言,三角函数的求值是很花时间的,我们应该尽可能的减少需要求值的次数——这样,我们继续优化。更改迭代的策略。在先前的代码中,我们在内层循环合并两个小DFT的每一个数据,而在外层循环中迭代每个小DFT。这样导致我们每次都需要在内层循环里面计算旋转因子Wnk。我们会自然的产生一个思路,可否再最外层循环里面处理小DFT中的每个数据,而在内层循环中处理每一个DFT对应位置的数据。
很显然,无论是从过程图上,还是从蝶形图看,这些小DFT都是互不干扰的,而且,它们大小都是一样的,因此该方案应当是可行的。这样,我们就减少了三角函数的计算次数。这之后,内层的循环会变成这样:
for (k = 0; k < midN; k++)
{
    WnkCalc(&wnk, curN, k);
    for (j = 0; j < mergeN; j++)
    {
        beg = j * curN;
        // do something
    }
}
我们继续尝试优化。这次的目标是减少不必要的运算。
考虑循环内的那个beg = j * curN,它是用来计算dat数据的下标用的。经过我们更改的循环顺序,内层循环的任务是处理每个小DFT的第k点数据。很明显,第1个DFT的第k个数据就是从k开始的,而第2个DFT的第k个数据则离第一个有当前的DFT大小那么大。持续到第n个,循环最终会处理到N点DFT的数据下标那。因此,我们可以尝试把j变成数组的下标。这步优化后,整个FFT的程序应该是非常高效了,最终的代码如下:
void fft_iter(complex<double> *dat)
{
    int              lay, index, curN, mergeN, j, k;
    int              midN;
    complex<double>  wnk, t;
    bitRev(dat, FFT_NUM);                               // 对数据进行反比特操作
    index = ilog2(FFT_NUM);
    for (lay = 1; lay <= index; lay++)                  // 从最底层开始计算DFT, 一共存在着log(2,N)层
    {
        curN = 1 << lay;                                // 当前合并的DFT大小
        midN = curN >> 1;                               // 被合并的DFT大小, 它是合并结果大小curN的一半
        mergeN = 1 << (index - lay);                    // 需要合并的次数
        for (k = 0; k < midN; k++)                      // 处理这些小DFT中的数据
        {
            WnkCalc(&wnk, curN, k);                     // 计算旋转因子
            for (j = k; j < FFT_NUM; j += curN)         // 处理每一个DFT相同的数据点
            {
                t = wnk * dat[j + midN];
                dat[j + midN] = dat[j] - t;
                dat[j] = dat[j] + t;
            }
        }
    }
}
当然,将三角函数变为查表,还可以变得更高效。更改Wnk计算为查表的完整FFT程序如下:
#include <stdio.h>
#include <iostream>
#include <complex>
#define  FFT_NUM         8
#define  PI2             6.28318530717958647692528676655900576839433
#define  MAX_TEMP        24
using namespace std;
complex<double>          wnkTable[FFT_NUM];
complex<double>          ffdat[FFT_NUM];
// 计算n的整数幂
int ilog2(int n)
{
    int  r = 0;
    while (n > 0)
    {
        n >>= 1;                                        // 2的整数幂都是0000 1 00000这样的二进制形式
        r++;
    }
    return r - 1;                                       // 2^0是1而不是0, 因此需要减去1位
}
// 反比特顺序
void bitRev(complex<double> *dat, int n)
{
    int  i, j, r, index;
    index = ilog2(n);                                   // 得到它是2的几次方
    for (i = 0; i < n; i++)
    {
        r = 0;
        for (j = 0; j < index; j++)
        {
            int value = 1 << j;
            if ((i & value) == value)                   // 发现第j+1位是1(因为j是从0开始的...)
                r |= 1 << (index - j - 1);              // 对应输出的颠倒过来的那个位设置为1
        }
        if (i < r)                                      // i < r才处理, 不然会重复交换, 等于白做
            swap(dat[i], dat[r]);
    }
}
// FFT_NUM的FFT
// Tips: 使用IFFT, 将wnk的虚数部分更改为相反数, 即可得到wn(-k)
void fft_iter(complex<double> *dat)
{
    int              lay, index, curN, mergeN, j, k;
    int              midN;
    complex<double>  wnk, t;
    bitRev(dat, FFT_NUM);                               // 对数据进行反比特操作
    index = ilog2(FFT_NUM);
    for (lay = 1; lay <= index; lay++)                  // 从最底层开始计算DFT, 一共存在着log(2,N)层
    {
        curN = 1 << lay;                                // 当前合并的DFT大小
        midN = curN >> 1;                               // 被合并的DFT大小, 它是合并结果大小curN的一半
        mergeN = 1 << (index - lay);                    // 需要合并的次数
        for (k = 0; k < midN; k++)                      // 处理这些小DFT中的数据
        {
            wnk = wnkTable[(FFT_NUM / curN) * k];       // 旋转因子, 它在单位圆上转圈圈, FFT_NUM是curN的整数倍, 因此这样写
            for (j = k; j < FFT_NUM; j += curN)         // 处理每一个DFT相同的数据点
            {
                t = wnk * dat[j + midN];
                dat[j + midN] = dat[j] - t;
                dat[j] = dat[j] + t;
            }
        }
    }
}
// 旋转因子wnk
void WnkCalc(complex<double> *o, int n, int k)
{
    double         d = (PI2 / n) * k;                   // k(2Pi/N)
    o->_Val[0] = cos(d);                                // o.real = cos(k(2Pi/N))
    o->_Val[1] = sin(d);                                // o.imag = sin(k(2Pi/N))
}

int main()
{
    int     i;
    double  r;
    for (i = 0; i < FFT_NUM; i++)
    {
        WnkCalc(wnkTable + i, FFT_NUM, i);              // 初始化, 计算Wnk的表

        r = PI2 / FFT_NUM;                              // 2Pi / N
        ffdat[i] = cos(i * r) + sin(2.0 * i * r);

        printf("inp[%d] = %lf + %lfi\n", i, ffdat[i].real(), ffdat[i].imag());
    }

    fft_iter(ffdat);

    for (i = 0; i < FFT_NUM; i++)
    {
        printf("fft[%d] = %lf + %lfi\n", i, ffdat[i].real(), ffdat[i].imag());
    }
    return 0;
}
我们来考虑一下它们各自的频域分辨率。
对于最开始的N点DFT,它的频域分辨率很容易得到:假设采样间隔为T,则为:
small
拆分之后,变成了两个N/2点的DFT,但是采样间隔变成了2T,因此:
small
可以看到,在拆分前后,它们的频域分辨率是不改变的。而时域分辨率由于之前的反比特顺序发生了更改。因而,它被叫做时域抽取法(Decimation In Time,DIT)。自然的,会有频域抽取法(Decimation In Frequency,DIF)的基-2的FFT。其蝶形图与文中所说的又是有差别的,是另外一种算法。若读者对上述Radix-2 FFT算法的DIF版本感兴趣,可自行了解。
在最后,我在这里引入主方法,并表明FFT算法是$O\left(nlogn\right)$的。主方法的证明可以自行去了解。在这里我们使用一个缩小版本——尽管这样说,它依然足以干掉绝大部分的递归或分治。
首先,我们有a个递归调用,每个递归调用将问题拆分为原始的1/b那么大,然后d表示在此之外需要进行的操作情况,也就是在此之外需要$O\left(n^d\right)$。比如这个操作是常数级的,那么d=0。注意,a, b, d都是和n无关的。此时,它的时间复杂度为:
$$T(n)= \begin{cases} O\left(n^dlogn\right) & \text{if} \qquad a=b^d\\\\ O\left(n^d\right) & \text{if} \qquad a<b^d\\\\ O\left(n^{log_ba}\right) & \text{if} \qquad a>b^d \end{cases}$$
主方法的证明可以在别处找到,这里不再说明。
以本文的基-2类FFT为例子,考虑递归版本的——因为迭代版只是对递归版的展开和改进。它有a=2次递归调用,每次都把问题拆解为b=2个子问题,在此之外,有一个复杂度是$O\left(n\right)$的几个小程序负责拷贝数组,因此d=1。很显然,它符合主方法的第一种情况:$a=b^d$,所以时间复杂度为$O\left(nlogn\right)$