FFT

Table of Contents

1. FFT

1.1. dft

  1. dft 的计算公式为: \(X_k=\frac{2}{N}\sum_{n=0}^{N-1}x_n\cdot e^{-i*2\pi k\frac{n}{N}}\), \(k \in 0 \to N-1\), k 表示不同频率
  2. 其中的 \(e^{i\omega}\) 根据欧拉公式表示 \(cos(\omega)+i\cdot sin(\omega)\), 所以 dft 也可以写成 \(X_k=\frac{2}{N}\sum_{n=0}^{N-1}x_n\cdot(cos(-2\pi k\frac{n}{N})+i\cdot sin(-2\pi k \frac{n}{N})))\),

    这里之所以用 \(-2\pi\) 而不是 \(2\pi\) 是有意让 sin 部分取负值, 后续计算 idft 会简单一点

  3. \(X_k\) 计算结果是复数的形式: \(X_k=amp\_of\_cos(k)+i \cdot amp\_of\_sin(k)\), 实际上表示两个数: cos(k) 和 sin(k) 的幅度
  4. N 称为 dft 的点数, 通常是 2 的 n 次幂, 例如 128, 因为后面 fft 算法要求 N 是 2 的 n 次幂.
  5. dft 简单的说就是 x 与不同频率的 cos (和 sin) 在 N 个采样点的逐位相乘后再累加, 做为对应 cos (和 sin) 的幅度. 如下图所示:

    • 红色是时域信号 x
    • 蓝色是对应 cos(k=2)
    • 绿色是对应 sin(k=5)
    • 竖的灰线表示 N = 7, 灰线与红线和蓝线有 7 组交点, 两两相乘再累加, 即是 cos(k=2) 的幅度

    dft.png

一个简单的 dft 的代码如下:

void my_dft(int n_point, kiss_fft_cpx *in, kiss_fft_cpx *out) {
    for (int i = 0; i < n_point; i++) {
        float tmp_r = 0.0f;
        float tmp_i = 0.0f;
        for (int k = 0; k < n_point; k++) {
            tmp_r += in[k].r * cos(-1 * i * 2.0 * PI * k / n_point);
            tmp_i += in[k].r * sin(-1 * i * 2.0 * PI * k / n_point);
        }
        out[i].r = tmp_r;
        out[i].i = tmp_i;
    }
}

1.2. idft

idft 与 dft 反过来:

\(x_n=\frac{1}{N}\sum_{k=0}^{N-1}X_k\cdot e^{i\cdot 2\pi k\frac{n}{N}}\)

按照 dft 的定义, x 是许多个频率不同的 cos 和 sin 的叠加, 所以计算 idft 时直接把所有频率 (k) 对应的 cos 和 sin 叠加起来即可.

由于 \(X_k\) 用复数表示, 所以展开后相当于 \(amp\_of\_cos(k)*cos(2\pi k \frac{n}{N})-amp\_of\_sin(k)*sin(2\pi k \frac{n}{N})\)

由于前面计算 dft 时 \(amp\_of\_sin(k)\) 实际是用的原值的对应的负值, 所以最终结果是正确的

简单的计算 idft 的代码:

void my_idft(int n_point, kiss_fft_cpx *in, kiss_fft_cpx *out) {
    for (int i = 0; i < n_point; i++) {
        float tmp = 0.0f;
        for (int k = 0; k < n_point; k++) {
            tmp += in[k].r * cos(2.0 * PI * i * k / n_point);
            /* 这里的 in[k].i 是负值, 所以用减号, 减号对应 i^2=-1 */
            tmp -= in[k].i * sin(2.0 * PI * i * k / n_point);
        }
        out[i].r = tmp;
        out[i].i = 0.0f;
    }
}

1.3. fft

fft 是计算 dft 的快速算法, 复杂度为 \(O(n\log(n))\) , 而前面提到的简单 dft 算法的复杂度是 \(O(n^2)\)

1.3.1. 对称性

对称性是指 \(X_{k+N}=X_k\), 推导过程为:

\(X_k = \sum_{n=0}^{N-1}{x_n\cdot e^{-i2\pi{kn/N}}}\)

\(X_{k+N} = \sum_{n=0}^{N-1}{x_n\cdot e^{-i2\pi{(k+N)n/N}}} = \sum_{n=0}^{N-1}{x_n\cdot e^{-i2\pi{n}}\cdot e^{-i2\pi{kn/N}}}\)

\(e^{-i2\pi{n}} = 1\)

\(X_{k+N}=X_k\)

1.3.2. 分治

把 N 个点按奇偶分成两部分:

\begin{eqnarray*} X_{k} &=& \sum_{n=0}^{N-1}{x_n\cdot e^{-i2\pi{kn/N}}} \\ &=& \sum_{m=0}^{N/2-1}{x_{2m}\cdot e^{-i2\pi{k(2m)/N}}} + \sum_{m=0}^{N/2-1}{x_{2m+1}\cdot e^{-i2\pi{k(2m+1)/N}}} \\ &=& \sum_{m=0}^{N/2-1}{x_{2m}\cdot e^{-i2\pi{km/(N/2)}}} + e^{-i2\pi{k/N}}\sum_{m=0}^{N/2-1}{x_{2m+1}\cdot e^{-i2\pi{km/(N/2)}}} \\ &=& E_k+factor*O_k \end{eqnarray*}

其中 \(E\) (even), \(O\) (odd) 是大小为 \(\frac{N}{2}\) 的 dft

对 \(E\) 和 \(O\) 来说, 由于 \(E_k=E_{k-\frac{N}{2}} \mid k>\frac{N}{2}\), 所以当计算 \(X_k \mid {k>\frac{N}{2}}\) 时, 可以复用前一次归并时计算好的 \(E_{k-\frac{N}{2}}\) 而不需要再重复计算 \(E_k\), 所以其复杂度是 \(O(n\log(n))\)

1.3.3. 递归实现

void my_recursive_fft(int n_point, kiss_fft_cpx *in, kiss_fft_cpx *out) {
    if (n_point == 1) {
        out[0].r = in[0].r;
        out[0].i = in[0].i;
        return;
    }
    kiss_fft_cpx *even_in =
        (kiss_fft_cpx *)KISS_FFT_MALLOC((n_point / 2) * sizeof(kiss_fft_cpx));
    kiss_fft_cpx *even_out =
        (kiss_fft_cpx *)KISS_FFT_MALLOC((n_point / 2) * sizeof(kiss_fft_cpx));
    kiss_fft_cpx *odd_in =
        (kiss_fft_cpx *)KISS_FFT_MALLOC((n_point / 2) * sizeof(kiss_fft_cpx));
    kiss_fft_cpx *odd_out =
        (kiss_fft_cpx *)KISS_FFT_MALLOC((n_point / 2) * sizeof(kiss_fft_cpx));

    for (int i = 0; i < n_point / 2; i++) {
        even_in[i] = in[i * 2];
        odd_in[i] = in[i * 2 + 1];
    }
    my_recursive_fft(n_point / 2, even_in, even_out);
    my_recursive_fft(n_point / 2, odd_in, odd_out);

    int mid = n_point / 2;
    for (int i = 0; i < mid; i++) {
        float twiddle_factor_r = cos(-2 * PI * i / n_point);
        float twiddle_factor_i = sin(-2 * PI * i / n_point);

        out[i].r = even_out[i].r + odd_out[i].r * twiddle_factor_r -
                   odd_out[i].i * twiddle_factor_i;
        out[i].i = even_out[i].i + odd_out[i].r * twiddle_factor_i +
                   odd_out[i].i * twiddle_factor_r;

        twiddle_factor_r = cos(-2 * PI * (i + mid) / n_point);
        twiddle_factor_i = sin(-2 * PI * (i + mid) / n_point);
        out[i + mid].r = even_out[i].r + odd_out[i].r * twiddle_factor_r -
                         odd_out[i].i * twiddle_factor_i;
        out[i + mid].i = even_out[i].i + odd_out[i].r * twiddle_factor_i +
                         odd_out[i].i * twiddle_factor_r;
    }
}

1.3.4. 非递归实现

由于分治时是以奇偶位置为准进行分组, 而不是像归并排序一样使用前半部/后半部, 导致分治时数据的分组会比较复杂.

可以事先交换数据的位置, 把它变成像归并排序一样使用前半部/后半部来分组, 以简化后续代码. 以 N = 8 为例:

0 1   2 3   4 5   6 7 
0 2   4 6 | 1 3   5 7 
0 4 | 2 6 | 1 5 | 3 7 

最终交换后的数据为 `0 4 2 6 1 5 3 7`, 即:

  1. 0 和 4 进行第一步归并,
  2. 0,4 和 2,6 进行第二步归并
  3. 0,4,2,6 和 1,3,5,7 进行最后一次归并.

原 x[1] 位置应该放上 4, x[3] 应该放上 6, 通过观察发现 1 与 4, 3 与 6 的二进制是互为镜像的: 001 与 100, 011 与 110.

extern float twiddle_table[];
void my_fft(
    int n_point, kiss_fft_cpx *in, kiss_fft_cpx *out, int with_twiddle_table) {
    int rev[n_point];
    for (int i = 0; i < n_point; i++) {
        rev[i] = 0;
    }
    /* NOTE: rev 用来保存 bit 前后翻转后的值, 然后根据 rev 对数据重排 */
    int bit = (int)log2(n_point);
    for (int i = 1; i < n_point; i++) {
        rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (bit - 1)));
    }
    memcpy(out, in, n_point * sizeof(in[0]));
    for (int i = 1; i < n_point; i++) {
        if (i < rev[i]) {
            kiss_fft_cpx tmp = out[i];
            out[i] = out[rev[i]];
            out[rev[i]] = tmp;
        }
    }
    /* NOTE: 下面的这段代码通常称为 butterfly (蝶形运算), 实际就是一个 DP */
    for (int mid = 1; mid < n_point; mid *= 2) {
        for (int j = 0; j < n_point; j += mid * 2) {
            for (int i = j; i < j + mid; i++) {
                kiss_fft_cpx even = out[i];
                kiss_fft_cpx odd = out[i + mid];

                int index = (int)log2(mid);
                float twiddle_factor_r = 0.0f;
                float twiddle_factor_i = 0.0f;
                if (with_twiddle_table) {
                    twiddle_factor_r = twiddle_table[index * N * 2 + i * 2];
                    twiddle_factor_i = twiddle_table[index * N * 2 + i * 2 + 1];
                } else {
                    twiddle_factor_r = cos(-1 * PI * i / mid);
                    twiddle_factor_i = sin(-1 * PI * i / mid);
                }
                /* NOTE: butterfly 主要体现在这一行代码 */                  
                out[i].r = even.r + odd.r * twiddle_factor_r -
                           odd.i * twiddle_factor_i;
                out[i].i = even.i + odd.r * twiddle_factor_i +
                           odd.i * twiddle_factor_r;
                if (with_twiddle_table) {
                    twiddle_factor_r =
                        twiddle_table[index * N * 2 + (i + mid) * 2];
                    twiddle_factor_i =
                        twiddle_table[index * N * 2 + (i + mid) * 2 + 1];
                } else {
                    twiddle_factor_r = cos(-1 * PI * (i + mid) / mid);
                    twiddle_factor_i = sin(-1 * PI * (i + mid) / mid);
                }
                /* NOTE: butterfly 主要体现在这一行代码 */
                out[i + mid].r = even.r + (odd.r * twiddle_factor_r -
                                           odd.i * twiddle_factor_i);
                out[i + mid].i = even.i + (odd.r * twiddle_factor_i +
                                           odd.i * twiddle_factor_r);
            }
        }
    }
}

1.3.5. 使用 twiddle table

\(even+factor*odd\) 使用的 factor 是一个常量, 可以提前计算出来, 叫做 twiddle table

产生 twiddle table:

void generate_twiddle_table(int n_point) {
    int bit = (int)log2(n_point);
    float twiddle_table[bit][n_point * 2];
    int32_t twiddle_table_fixed[bit][n_point * 2];
    for (int mid = 1; mid < n_point; mid *= 2) {
        for (int j = 0; j < n_point; j += mid * 2) {
            for (int i = j; i < j + mid; i++) {
                int index = (int)log2(mid);
                twiddle_table[index][i * 2] = cos(-1 * PI * i / mid);
                twiddle_table[index][i * 2 + 1] = sin(-1 * PI * i / mid);
                twiddle_table[index][(i + mid) * 2] =
                    cos(-1 * PI * (i + mid) / mid);
                twiddle_table[index][(i + mid) * 2 + 1] =
                    sin(-1 * PI * (i + mid) / mid);

                twiddle_table_fixed[index][i * 2] =
                    (int)(twiddle_table[index][i * 2] * (1 << 15));
                twiddle_table_fixed[index][i * 2 + 1] =
                    (int)(twiddle_table[index][i * 2 + 1] * (1 << 15));
                twiddle_table_fixed[index][(i + mid) * 2] =
                    (int)(twiddle_table[index][(i + mid) * 2] * (1 << 15));
                twiddle_table_fixed[index][(i + mid) * 2 + 1] =
                    (int)(twiddle_table[index][(i + mid) * 2 + 1] * (1 << 15));
            }
        }
    }
    FILE *fp = fopen("twiddle_table.c", "w");
    fprintf(fp, "float twiddle_table[%d*%d] = {\n", bit, n_point * 2);
    for (int i = 0; i < bit * n_point * 2; i++) {
        fprintf(fp, "%f,", ((float *)twiddle_table)[i]);
    }
    fprintf(fp, "};\n");

    fprintf(fp, "#include <stdint.h>\n");
    fprintf(fp, "int32_t twiddle_table_fixed[%d*%d] = {\n", bit, n_point * 2);
    for (int i = 0; i < bit * n_point * 2; i++) {
        fprintf(fp, "%d,", ((int32_t *)twiddle_table_fixed)[i]);
    }
    fprintf(fp, "};\n");
    fclose(fp);
}

1.3.6. 定点化

由于 fft 是线性变换, 所以定点化还是比较简单的. 为了避免对 sin, cos 定点化, 可以利用 twiddle table, 提前把 twiddle table 的结果转换为定点 (参考前面的 twiddle_table_fixed)

extern int32_t twiddle_table_fixed[];
typedef struct {
    int32_t r;
    int32_t i;
} kiss_fft_cpx_fixed;

static inline int32_t MUL_Q15(int32_t a, int32_t b) {
    return ((int64_t)a * (int64_t)b) >> 15;
}

void my_fft_fixed(int n_point, kiss_fft_cpx *in, kiss_fft_cpx *out) {
    int rev[n_point];
    for (int i = 0; i < n_point; i++) {
        rev[i] = 0;
    }
    int bit = (int)log2(n_point);
    for (int i = 1; i < n_point; i++) {
        rev[i] = (rev[i >> 1] >> 1 | ((i & 1) << (bit - 1)));
    }
    memcpy(out, in, n_point * sizeof(in[0]));
    for (int i = 1; i < n_point; i++) {
        if (i < rev[i]) {
            kiss_fft_cpx tmp = out[i];
            out[i] = out[rev[i]];
            out[rev[i]] = tmp;
        }
    }
    /* convert out to fixed point */
    kiss_fft_cpx_fixed *out_fixed =
        malloc(n_point * sizeof(kiss_fft_cpx_fixed));
    for (int i = 0; i < n_point; i++) {
        out_fixed[i].r = (int32_t)(out[i].r * (1 << 15));
        out_fixed[i].i = (int32_t)(out[i].i * (1 << 15));
    }

    for (int mid = 1; mid < n_point; mid *= 2) {
        for (int j = 0; j < n_point; j += mid * 2) {
            for (int i = j; i < j + mid; i++) {
                kiss_fft_cpx_fixed even = out_fixed[i];
                kiss_fft_cpx_fixed odd = out_fixed[i + mid];

                int index = (int)log2(mid);
                int32_t twiddle_factor_r = 0;
                int32_t twiddle_factor_i = 0;
                twiddle_factor_r = twiddle_table_fixed[index * N * 2 + i * 2];
                twiddle_factor_i =
                    twiddle_table_fixed[index * N * 2 + i * 2 + 1];

                out_fixed[i].r = even.r + MUL_Q15(odd.r, twiddle_factor_r) -
                                 MUL_Q15(odd.i, twiddle_factor_i);

                out_fixed[i].i = even.i + MUL_Q15(odd.r, twiddle_factor_i) +
                                 MUL_Q15(odd.i, twiddle_factor_r);

                twiddle_factor_r =
                    twiddle_table_fixed[index * N * 2 + (i + mid) * 2];
                twiddle_factor_i =
                    twiddle_table_fixed[index * N * 2 + (i + mid) * 2 + 1];

                out_fixed[i + mid].r = even.r +
                                       MUL_Q15(odd.r, twiddle_factor_r) -
                                       MUL_Q15(odd.i, twiddle_factor_i);

                out_fixed[i + mid].i = even.i +
                                       MUL_Q15(odd.r, twiddle_factor_i) +
                                       MUL_Q15(odd.i, twiddle_factor_r);
            }
        }
    }
    /* convert back to float */
    for (int i = 0; i < n_point; i++) {
        out[i].r = out_fixed[i].r >> 15;
        out[i].i = out_fixed[i].i >> 15;
    }
}

1.3.7. complex dft

TBD

1.3.8. kiss_fft

TBD

1.3.9. Q

  • 为什么 dft 的结果是复数

    \(X_k\) 表示成复数是为了把 \(cos(k)\) 和 \(sin(k)\) 对应的幅度写在一起. 另外使用 \(X_k\) 的其它运算例如 idft 使用复数运算写起来会简洁一些, 但本质上直接用 \(cos(k)\) 和 \(sin(k)\) 来写也是一样的.

  • 为什么 N 的点的 dft 的输出也是 N 个复数

    \(X_{k+N}=X_k\)

  • 为什么 \(X_{k}\) 的 \(\sin\) 部分是其真实值的负数

    只是一种写法, 用时记得取负号就可以. 更深层的原因是为了后续 idft 等计算方便, 例如 \((a+ib)*(\cos+i\cdot \sin)=(a\cdot \cos - b\cdot \sin)+i(ad+bc)\)

  • 为什么计算 \(X_{k}\) 时需要乘 \(\frac{2}{N}\) 而计算 idft 时是 \(\frac{1}{N}\)

Author: [email protected]
Date: 2022-09-29 Thu 22:01
Last updated: 2023-06-29 Thu 18:30

知识共享许可协议