模版

FFT
const int maxn = (1<<22) + 5;  // 注意这里需要是 > 2^k

struct Complex {
    double x,y;
    Complex(double _x = 0.0, double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 1; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 1; i < n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;

// calculate h(x) = f(x) * g(x), n1 = deg(f) + 1, n2 = deg(g) + 1
void poly_multiply(const double f[], int n1, const double g[], int n2, double h[]) {
    int n = 1;
    n1--, n2--;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn];
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = a[i].x;
}

int n1, n2;
double f[maxn/2], g[maxn/2], h[maxn];

int main() {
    cin >> n1 >> n2;  // deg(f) = n1 - 1, deg(g) = n2 - 1
    for (int i = 0; i < n1; i++) cin >> f[i];
    for (int i = 0; i < n2; i++) cin >> g[i];
    poly_multiply(f, n1, g, n2, h);
    for (int i = 0; i <= n1+n2-2; i++) cout << h[i] << " ";
}

使用板子的注意事项:

  1. maxn 的值必须大于一个 $2^k$,且 $2^k > n_1+n_2$

介绍

FFT (Fast Fourier Transform, 快速傅立叶变换) 可以在 $O(n \log n)$ 的时间内计算两个多项式的乘法,也可以用来计算卷积。

• 一般卷积的形式为 $$C_k = \sum\limits_{i=0}^kA_iB_{k-i} = A_0B_k + A_1B_{k-1} + … + A_kB_0$$

在开始介绍 FFT 之前,我们需要一些基本的前置知识。

多项式

一个多项式有两种表示方法:系数表示法点值表示法

系数表示法

$$f(x) = a_0 + a_1x + a_2x^2 + … + a_nx^n$$

则 $\{a_0,a_1,…,a_n\}$ 即可表示这个多项式 $f(x)$。

点值表示法

选取多项式函数上的 $(n+1)$ 个点 $(x_0,y_0), (x_1,y_1), …, (x_n,y_n)$ 也可以唯一的表示这个多项式 $f(x)$。

证明:使用高斯消元即可,$(n+1)$ 个未知量,对应 $(n+1)$ 个方程。

复数

$n$ 次单位根

若 $\omega$ 满足 $\omega^n = 1$,那么 $\omega$ 就是 $n$ 次单位根 (n-th root of unity)。

若 $\omega$ 满足 $\omega^n = 1$,且满足 $\omega ^ k \neq 1, \forall k \in [1, n-1]$,则 $\omega$ 就是 $n$ 次本原单位根 (primitive n-th root of unity)

$\omega$ 的取值

定理

$$\omega_n^k = e^{i\frac{2k\pi}{n}} = \cos(\frac{2k\pi}{n}) + i \sin(\frac{2k\pi}{n})$$

$$\omega = e^{i \frac{2\pi}{n}} = \cos(\frac{2\pi}{n}) + i \sin(\frac{2\pi}{n})$$

为了方便,下文我们就写 $\omega = e^{i \frac{2\pi}{n}}$

$\omega$ 有几个优秀的性质:

  1. $\omega$ 为 $n$ 次本原单位根,这说明 $\forall k \in [1, n-1], \omega^n = 1$,且 $\omega ^ k \neq 1$
  2. $\omega^{k+\frac{n}{2}} = -\omega^{k}$

FFT原理

问题

给定两个已知多项式 $f(x)$ 和 $g(x)$,$deg(f) = n_1, deg(g) = n_2$,求 $h(x) = f(x) * g(x)$?

FFT 分为三步:

  1. DFT:令 $n = n_1 + n_2$,在 $x$ 轴上选择 $(n + 1)$ 个特殊点 $x_0,x_1,…,x_{n}$,求出 $f(x_0), f(x_1), …, f(x_n)$ 与 $g(x_0), g(x_1), …, g(x_n)$ 的值。

    复杂度:$O(n\log n)$

  2. 点值乘法:求出 $(n+1)$ 个点上,点值的乘积,即 $h(x_i) = f(x_i) * g(x_i), i \in [0, n]$

    复杂度:$O(n)$

  3. IDFT:给定一个未知 $h(x)$ 上的 $(n+1)$ 个特殊点 $\{(x_0,y_0), (x_1,y_1), …, (x_n,y_n)\}$,求出 $h(x)$ 的系数表达式。

    复杂度:$O(n\log n)$

第一步 DFT

问题

假设现在有一个 $deg(f) = n$ 的多项式,且 $n$ 为偶数。我们想要求 $f(x)$ 在 $n$ 个不同的特殊点 $x_0,x_1,…,x_{n-1}$ 上的值。

注意到这个问题也等价于一个矩阵乘法,即给定 一个矩阵 $X$ 和一个向量 $\vec a$,在我们可以自由选择 $x_0,x_1,…,x_{n-1}$ 的情况下,用 $O(n \log n)$ 的时间求出 $\vec y = X \vec a$ 的值?

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix} \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} = \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} $$

那么,如何快速求出 $f(x_0), f(x_1), …, f(x_{n-1})$呢?我们只要选一些特殊的 $x_0, x_1, …, x_{n-1}$ 即可。

结论:$x_i = \omega^i$


证明:

因为

$$f(x) = a_0 + a_1x + a_2x^2 + … + a_nx^n$$

$$=(a_0 + a_2x^2 + a_4x^4 + … + a_{n-2}x^{n-2}) + (a_1x + a_3x^3 + a_5x^5 + … + a_{n-1}x^{n-1})$$

$$f_{even}(x) = a_0 + a_2x + a_4x^2 + … + a_{n-2}x^{\frac{n}{2} - 1}$$ $$f_{odd}(x) = a_1 + a_3x + a_5x^2 + … + a_{n-1}x^{\frac{n}{2} - 1}$$

则有

$$f(x) = f_{even}(x^2) + xf_{odd}(x^2)$$

现在,我们选择 $x_i = \omega^i$,则我们求 $f(x_i)$ 和 $f(x_{i+\frac{n}{2}})$,可以发现

$$x_{\frac{n}{2}} = \omega^{\frac{n}{2}} = -1, ~~x_{i+\frac{n}{2}} = \omega^{i+\frac{n}{2}} = -\omega^i = -x_i$$

所以

$$\forall i \in [0,\frac{n}{2}-1], ~~f(x_i) = f_{even}(x_i^2) + x_if_{odd}(x_i^2), ~~f(x_{i+\frac{n}{2}}) = f_{even}(x_i^2) - x_if_{odd}(x_i^2)$$

又因为,求 $\forall i \in [0,n], f(x_i)$ 的值,等价于求 $\forall i \in [0,\frac{n}{2}-1], f(x_i) 和 f(x_{i+\frac{n}{2}})$ 的值,

所以这个问题就可以等价转化为求 $f_{even}(x_i^2) + x_if_{odd}(x_i^2)$ 的值。而这就是一个规模减半了的子问题。

所以设 $T(n)$ 为:求出 $\forall i \in [0,n], f(x_i)$ 的值所需的时间,则可以列出方程:

$$T(n) = 2T(\frac{n}{2}) + O(n)$$

所以 $T(n) = O(n\log n)$

第二步 点值乘法

第一步DFT以后,我们有了 $f(x_0), f(x_1), …, f(x_{n-1})$ 与 $g(x_0), g(x_1), …, g(x_{n-1})$ 的值,我们就可以直接将这些值乘起来,得到

$$h(x_0), h(x_1), …, h(x_{n-1})$$

的值。

第三步 IDFT

问题

现在给定一个未知的多项式 $h(x)$,且 $deg(h) = n-1$,已知 $h(x)$ 上的 $n$ 个点 $(x_0,h(x_0)), (x_1,h(x_1)), …, (x_{n-1},h(x_{n-1}))$,求 $h(x)$ 的表达式?

注意到这个问题也等价于一个矩阵乘法,即给定 一个矩阵 $X$ 和一个向量 $\vec y$,已知 $X\vec a = \vec y$,如何用 $O(n \log n)$ 的时间求出 $\vec a$ 的值?

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix} \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} = \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} $$

发现了什么?这实际上和第一步 DFT 的过程是一样的,如果我们把 $X^{-1}$ 给两边乘上 ,就变成了

$$\begin{bmatrix} 1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix}^{-1} \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix}$$

那么,令

$$X=\begin{bmatrix}1 & x_0 & x_0^2 & … & x_0^{n-1} \\
1 & x_1 & x_1^2 & … & x_1^{n-1} \\
…\\
1 & x_{n-1} & x_{n-1}^2 & … & x_{n-1}^{n-1} \\
\end{bmatrix}$$

把 $x_i = \omega^i$ 代进去,就可以得到:

$$X = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^1 & \omega^{1\cdot 2} & … & \omega^{1\cdot (n-1)} \\
…\\
1 & \omega^{1} & \omega^{(n-1)\cdot 2} & … & \omega^{(n-1)\cdot (n-1)} \\
\end{bmatrix}$$

$$V(\omega) = X = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^1 & \omega^{1\cdot 2} & … & \omega^{1\cdot (n-1)} \\
…\\
1 & \omega^{1} & \omega^{(n-1)\cdot 2} & … & \omega^{(n-1)\cdot (n-1)} \\
\end{bmatrix}, ~~~V(\omega^{-1}) = \begin{bmatrix}1 & 1 & 1 & … & 1 \\
1 & \omega^{-1} & \omega^{-1\cdot 2} & … & \omega^{-1\cdot (n-1)} \\
…\\
1 & \omega^{-1} & \omega^{-(n-1)\cdot 2} & … & \omega^{-(n-1)\cdot (n-1)} \\
\end{bmatrix}$$

则有:

$$V(\omega) \cdot V(\omega^{-1}) = nI$$

证明:直接进行矩阵运算即可,注意要用到 $\omega^n = 1$ 的特性。

这说明

$$X^{-1} = V(w)^{-1} = \frac{1}{n} V(\omega^{-1})$$

所以上面的 $X^{-1}\vec y = \vec a$ 就可以表示为:

$$V(\omega)^{-1} \begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix} ~~\Rightarrow~~ \frac{1}{n}V(\omega^{-1})\begin{bmatrix} y_0 \\
y_1 \\
…\\
y_{n-1} \\
\end{bmatrix} = \begin{bmatrix} a_0 \\
a_1 \\
…\\
a_{n-1} \\
\end{bmatrix}$$

这就等价于我们有一个系数为 $y_0, y_1, … y_{n}$ 的多项式,然后要在 $x_0^{-1}, x_1^{-1}, …, x_n^{-1}$ 这些特殊点上进行求值。

我们发现 $x_0^{-1} = \omega^{-1}$ 仍然是一个 $n$ 次本原单位根 (primitive n-th root of unity)(证明略),所以它仍然有上述优秀的性质。

所以这个问题可以用 DFT 解决,求出来的值,就是我们想要的 $h(x)$ 的多项式系数。

注意一下,上述所有过程都是完美的分治过程,所以我们需要把 $deg(h)$ 补成一个 $2^k$ 形式,其中 $deg(h) = 2^k \geq 2(n_1+n_2)$

非递归 FFT

递归 FFT 常数过大,不好用,所以我们可以写一个非递归版本的。

我们手推一下 FFT 的分治过程,每一次分治都是拆分奇数项和偶数项,模拟拆分过程有:

初始序列:$\{a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7\}$

第一次拆分:$\{a_0,a_2,a_4,a_6\},\{a_1,a_3,a_5,a_7\}$

第二次拆分:$\{a_0,a_4\},\{a_2,a_6\},\{a_1,a_5\},\{a_3,a_7\}$

第三次拆分:$\{a_0\},\{a_4\},\{a_2\},\{a_6\},\{a_1\},\{a_5\},\{a_3\},\{a_7\}$

发现了什么?

拆分前:0, 1, 2, 3, 4, 5, 6, 7

拆分后:0, 4, 2, 6, 1, 5, 3, 7

化成二进制:

拆分前:000, 001, 010, 011, 100, 101, 110, 111

拆分后:000, 100, 010, 110, 001, 101, 011, 111

机智的你或许发现了,拆分前后,每个数字对应的二进制刚好就是翻转了一下,最左边的bit跑到最右边去了,反之亦然。

那么这个二进制翻转就可以直接 $O(n \log n)$ 实现了,但是这样不够高效,有一个 $O(n)$ 的基于 DP 的方法:

设 $R(x)$ 为 $x$ 翻转后的结果,可知 $R(0) = 0$。

我们从小到大求 $R(x)$,在求 $R(x)$ 时,我们已知了 $R(\frac{x}{2})$ 的值,相当于:

$$x = abc[0/1], ~ \frac{x}{2} = abc, ~ R(\frac{x}{2}) = cba, ~ R(x) = [0/1]cba$$

则我们只要判断一下 $x$ 的最低位为 $0$ 还是 $1$,然后给 $R(\frac{x}{2})$ 的最高位补上一个 $0$ 或者 $1$ 就可以得到 $R(x)$ 了。

void rearrange(Complex a[], const int n) {
    static int rev[maxn];  
    for (int i = 0; i <= n; i++) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) rev[i] |= (n >> 1);
    }
    for (int i = 0; i <= n; i++) {
        if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
    }
}

至于非递归写法的原理,我懒得写了先鸽着吧(

优化(三步变两步)

暂时没遇到需要优化的情况,先鸽着。

例题

例1 洛谷P1919 【模板】A*B Problem升级版(FFT快速傅里叶)

题意

给定正整数 $a,b$,求 $a*b$?

其中,$1\leq a,b \leq 10^{10^6}$

题解

把大整数看成多项式,例如 $356$ 就看作 $f(x) = 3x^2 + 5x + 6$

两个大整数相乘就看作多项式相乘,最后代入 $x = 10$,处理一下进位和前缀零即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = (1<<21) + 5;

struct Complex {
    long double x,y;
    Complex(long double _x = 0.0, long double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;


// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(int f[], int n1, int g[], int n2, int h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = (int)(a[i].x + 0.5);

}

int n1, n2;
int f[maxn/2], g[maxn/2], h[maxn];

int main() {

    string s1, s2;
    cin >> s1 >> s2;
    int n1 = s1.size() - 1, n2 = s2.size() - 1;
    for (int i = 0; i <= n1; i++) f[i] = s1[n1-i] - '0';
    for (int i = 0; i <= n2; i++) g[i] = s2[n2-i] - '0';
    poly_multiply(f, n1, g, n2, h);

    int i = 0;
    for (i = 0; i < maxn-2; i++) {
        h[i+1] += h[i] / 10;
        h[i] %= 10;
    }
    string ans = "";
    for (int i = 0; i < maxn-2; i++) ans += (char)(h[i] + '0');

    for (int j = ans.size()-1; j >= 0; j--) {
        if (ans[j] == '0') {
            ans.erase(ans.begin() + j);
        } else break;
    }

    reverse(ans.begin(), ans.end());
    cout << ans << endl;
}

例2 洛谷P3338 [ZJOI2014]力

题意

给定 $n$ 个实数 $q_1,q_2,…,q_n$,定义:

$$F_j = \sum\limits_{i=1}^{j-1}\frac{q_i \times q_j}{(i-j)^2} - \sum\limits_{i=j+1}^{n}\frac{q_i \times q_j}{(i-j)^2}$$

$$E_j = \frac{F_j}{q_j}$$

求 $E_1, E_2, …, E_n$ 的值?

其中,$n \leq 10^5$

题解

FFT 可以用于处理卷积问题,我们先化简一下式子:

$$E_j = \frac{F_j}{q_j} = \sum\limits_{i=1}^{j-1}\frac{q_i}{(i-j)^2} - \sum\limits_{i=j+1}^{n}\frac{q_i}{(i-j)^2}$$

$$f_i = q_i, ~~g_i = \frac{1}{i^2}$$

$$E_j = \sum\limits_{i=1}^{j-1}f_ig_{j-i} - \sum\limits_{i=j+1}^{n}f_ig_{i-j}$$

注意到第一项 $\sum\limits_{i=1}^{j-1}f_ig_{j-i}$ 可以写成 $i=0$ 开始,到 $j$ 结束(令 $f_0 = g_0 = 0$ 即可),所以第一项等于 $\sum\limits_{i=0}^{j}f_ig_{j-i}$,也就是一个卷积形式,可以利用 FFT 计算。


第二项 $\sum\limits_{i=j+1}^{n}f_ig_{i-j}$ 怎么写成卷积呢?

注意到,如果我们令 $f_j’ = f_{n-j}$,则有

$$\sum\limits_{i=j+1}^{n}f_ig_{i-j} = \sum\limits_{i=j}^{n}f’_{n-i}g_{i-j} = g_0f’_{n-j} + g_1f’_{n-j-1} + … + g_{n-j}f’_0$$

那么,这就又是一个卷积了,所以最后我们的结果可以表示为:

$$E_j = \sum\limits_{i=0}^{j}f_ig_{j-i} - \sum\limits_{i=j}^{n}f’_{n-i}g_{i-j}$$


第一项 $\sum\limits_{i=0}^{j}f_ig_{j-i}$:计算多项式 $f(x) * g(x)$,取第 $j$ 项的系数即可。

第二项 $\sum\limits_{i=j}^{n}f’_{n-i}g_{i-j}$:计算多项式 $f’(x) * g(x)$,取第 $n-j$ 项的系数即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = (1<<21) + 5;

struct Complex {
    double x,y;
    Complex(double _x = 0.0, double _y = 0.0) {
        x = _x; y = _y;
    }
    Complex operator+(const Complex& b) const {
        return Complex(x + b.x, y + b.y);
    }
    Complex operator-(const Complex& b) const {
        return Complex(x - b.x, y - b.y);
    }
    Complex operator*(const Complex &b) const {
        return Complex(x * b.x - y * b.y, x * b.y + y * b.x);
    }
    Complex operator/(const long double& b) const {
        return Complex(x/b, y/b);
    }
};

struct FastFourierTransform {
    void rearrange(Complex a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void fft(Complex a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            Complex wn(cos(2*pi / k), on * sin(2*pi / k));
            for (int i = 0; i < n; i += k) {
                Complex w(1, 0);
                for (int j = i; j < i + (k>>1); j++) {
                    Complex x = a[j], y = w * a[j+(k>>1)];
                    a[j] = x + y;
                    a[j+(k>>1)] = x - y;
                    w = w * wn;
                }
            }
        }
        if (on == -1) {
            for (int i = 0; i < n; i++) a[i] = a[i] / n;
        }
    }
} fft;

// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(const double f[], int n1, const double g[], int n2, double h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    static Complex a[maxn], b[maxn];
    memset(a, 0, sizeof(a));
    memset(b, 0, sizeof(b));
    for (int i = 0; i <= n1; i++) a[i] = Complex(f[i], 0);
    for (int i = 0; i <= n2; i++) b[i] = Complex(g[i], 0);
    fft.fft(a, n, 1);  // 注意这里用的是 n (不是 n1)
    fft.fft(b, n, 1);
    for (int i = 0; i <= n; i++) a[i] = a[i] * b[i];
    fft.fft(a, n, -1);
    for (int i = 0; i <= n; i++) h[i] = a[i].x;
}

double f[maxn/2], g[maxn/2], h[maxn];
double f2[maxn], h2[maxn];

int main() {
    int n; cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> f[i];
        g[i] = (1.0/i/i);
    }

    poly_multiply(f, n, g, n, h);
    for (int i = 0; i <= n; i++) f2[i] = f[n-i];  // f'[i] = f[n-i]
    poly_multiply(f2, n, g, n, h2);

    for (int i = 1; i <= n; i++) {
        printf("%.3f\n", h[i] - h2[n-i]);
    }
}

例3 洛谷P3723 [AH2017/HNOI2017]礼物

题意

给定两个长度为 $n$ 的正整数序列 $x,y$,我们进行一次以下操作:

第一步:给其中一个序列的所有数加上一个非负整数 $c$。

第二步:旋转一个序列,可以旋转任意长度。

序列的旋转意思是全部元素左移,或者右移,超出界限的部分从另外一边补齐。

例如: $[1,2,3,4,5]$ 向右旋转 $2$ 个长度,得到 $[4,5,1,2,3]$。

现在我们想知道,如何选择 $c$ 和旋转序列的方式,使得

$$\sum\limits_{i=1}^n(x_i-y_i)^2$$

最小?

其中,$1\leq n \leq 50000, 1\leq x_i, y_i \leq 100$

题解

首先由于第一步中,我们可以选择任意一个序列来加上 $c$,所以可以看作选择一个序列,然后加上一个 任意整数 $c$

最后我们设旋转后,对齐的两个序列为 $a,b$。

$$\sum\limits_{i=1}^n(x_i-y_i)^2$$

可以表示为:

$$\sum\limits_{i=1}^n(a_i+c-b_i)^2$$

$$=\sum\limits_{i=1}^n ((a_i^2+b_i^2+c^2)+2c(a_i-b_i) - 2a_ib_i)$$

$$=nc^2 + \sum\limits_{i=1}^n (a_i^2+b_i^2)+2c\sum\limits_{i=1}^n(a_i-b_i) - \sum\limits_{i=1}^n 2a_ib_i$$

如果我们暂时不考虑 $c$,则会发现,无论最终旋转的 $a,b$ 是什么样的,前三项均为定值,只有 $\sum\limits_{i=1}^n 2a_ib_i$ 的值是个变量。

所以我们只需要 最大化 $\sum\limits_{i=1}^n a_ib_i$ 的值即可。


现在问题转化为,如何求一种旋转方式,使得 $\sum\limits_{i=1}^n a_ib_i$ 的值最大?

我们不妨直接把每一种旋转方式的值都求出来,然后取最大即可。

怎么全部一次性求出呢?卷积

$$E_j = \sum\limits_{i=j}^na_ib_{i-j+1} + \sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j}$$

则 $E_j$ 所代表的就是,如果我们如此排列:

$a_j$ $a_{j+1}$ $a_n$ $a_1$ $a_2$ $a_{j-1}$
$b_1$ $b_2$ $b_{n+1-j}$ $b_{n+2-j}$ $b_{n+3-j}$ $b_{n}$

最后得出来对应位置的乘积的和,就是 $E_j$。

所以我们只要求出 $E_1, E_2, …, E_n$,然后取最小值即可。

如何求呢?


先考虑第一项 $\sum\limits_{i=j}^na_ib_{i-j+1}$

把它转成卷积的形式,令 $c_i = a_{n-i}$,则有:

$$\sum\limits_{i=j}^na_ib_{i-j+1} = \sum\limits_{i=j}^nc_{n-i}b_{i-j+1} = b_1c_{n-j} + b_2c_{n-j-1} + … + b_{n-j+1}c_{0}$$

令 $a_0 = b_0 = 0$,就可以再给上式加一个 $b_0c_{n-j+1}$,得到:

$$b_0c_{n-j+1} + b_1c_{n-j} + b_2c_{n-j-1} + … + b_{n-j+1}c_{0}$$

那么这就是一个卷积了,相当于 $b(x) * c(x)$ 的第 $(n-j+1)$ 项。


再考虑第二项 $\sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j}$

仍然转成卷积形式,令 $d_i = b_{n-i}$,则有:

$$\sum\limits_{i=1}^{j-1}a_ib_{n+1+i-j} = \sum\limits_{i=1}^{j-1}a_id_{j-i-1} = a_1d_{j-2} + a_2d_{j-3} + … + a_{j-1}d_{0}$$

由于 $a_0 = 0$,给上式加上 $a_0d_{j-1}$,得到:

$$a_0d_{j-1} + a_1d_{j-2} + a_2d_{j-3} + … + a_{j-1}d_{0}$$

这又是一个卷积,相当于 $a(x) * d(x)$ 的第 $(j-1)$ 项。


所以 $\sum\limits_{i=1}^n a_ib_i$ 的最大值求出来了,剩下的变量只有 $c$ 了。

注意到 $1 \leq x_i,y_i \leq 100$,所以 $c$ 只要取遍所有的可能值,然后判断最小值即可。

在代码中,$c \in [-200, 200]$。

代码
#include <bits/stdc++.h>
using namespace std;

const int mod = 998244353;
const int maxn = (1<<20) + 5;

struct NTT {
    const ll g = 3, invg = inv(g);  // mod = 998244353
    inline ll qpow(ll a, ll b) {
        ll res = 1;
        while (b) {
            if (b & 1) res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        return res;
    }
    inline ll inv(ll a) {
        return qpow(a, mod-2);
    }

    void rearrange(ll a[], const int n) {
        static int rev[maxn];  // maxn > deg(h) 且 maxn 为 2的k次方 + 5
        for (int i = 0; i <= n; i++) {
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) rev[i] |= (n >> 1);
        }
        for (int i = 0; i <= n; i++) {
            if (i < rev[i]) swap(a[i], a[rev[i]]);  // 保证每对数字只翻转一次
        }
    }

    void ntt(ll a[], const int n, int on) {
        rearrange(a, n);
        for (int k = 2; k <= n; k <<= 1) {   // 模拟分治的合并过程
            ll wn = qpow(on == 1 ? g : invg, (mod-1)/k);
            for (int i = 0; i < n; i += k) {
                ll w = 1;
                for (int j = i; j < i + (k>>1); j++) {
                    ll x = a[j], y = w * a[j+(k>>1)] % mod;
                    a[j] = (x + y) % mod;
                    a[j+(k>>1)] = (x - y + mod) % mod;
                    w = w * wn % mod;
                }
            }
        }
        if (on == -1) {
            ll invn = inv(n);
            for (int i = 0; i < n; i++) a[i] = a[i] * invn % mod;
        }
    }
} ntt;

// calculate h(x) = f(x) * g(x), deg(f) = n1, deg(g) = n2
void poly_multiply(ll f[], int n1, ll g[], int n2, ll h[]) {
    int n = 1;
    while (n <= n1 + n2) n <<= 1;  // deg(h) = n1 + n2
    memset(h, 0, sizeof(ll) * n);
    ntt.ntt(f, n, 1);  // 注意这里用的是 n (不是 n1)
    ntt.ntt(g, n, 1);
    for (int i = 0; i <= n; i++) h[i] = f[i] * g[i] % mod;
    ntt.ntt(h, n, -1);
}

int n1, n2;
ll a[maxn/2], b[maxn/2], c[maxn/2], d[maxn/2], r1[maxn], r2[maxn];

int n,m;
ll ans = 0;

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i], c[n-i] = a[i];
    for (int i = 1; i <= n; i++) cin >> b[i], d[n-i] = b[i];
    ll absum = 0;  // a_i - b_i
    for (int i = 1; i <= n; i++) absum += a[i] - b[i];
    ll res = 1e18;
    for (ll c = -200; c <= 200; c++) {
        res = min(res, n * c * c + 2LL * c * absum);
    }

    ans += res;

    for (int i = 1; i <= n; i++) {
        ans += (a[i] * a[i]) + (b[i] * b[i]);
    }

    poly_multiply(a, n, d, n, r1);  // r1 = a*d
    poly_multiply(b, n, c, n, r2);  // r2 = b*c

    ll ab = 0;
    for (int i = 1; i <= n; i++) {
        ab = max(ab, (r1[i-1] + r2[n-i+1]));
    }

    ans -= 2LL * ab;

    cout << ans << endl;
}

参考资料

  1. HKU COMP3250 FFT课件
  2. https://oi-wiki.org/math/poly/fft/
  3. https://oi.men.ci/fft-notes/
  4. https://zhuanlan.zhihu.com/p/128661674