拉格朗日插值

给定 $n$ 个点,我们可以确定唯一的最高 degree 为 $(n-1)$ 的多项式。

设多项式为 $f(x)$,第 $i$ 个点的坐标为 $(x_i,y_i)$,那么这个多项式在 $k$ 处的取值为:

$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-x_j}{x_i-x_j}$$

时间复杂度:$O(n^2)$

拉格朗日插值板子
const int maxn = 2005;

int n;
ll k;
ll x[maxn], y[maxn];

ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

ll solve(ll k) {
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        ll nu = 1, de = 1;
        for (int j = 1; j <= n; j++) {
            if (j == i) continue;
            nu = nu * (k - x[j] + mod) % mod;
            de = de * (x[i] - x[j] + mod) % mod;
        }
        de = qpow(de, mod-2);
        ans = (ans + y[i] * nu % mod * de % mod) % mod;
    }
    return ans;
}

int main() {
    cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> x[i] >> y[i];
    ll ans = solve(k);
    cout << ans << endl;
}

$x$ 坐标为连续整数的拉格朗日插值

如果 $x$ 坐标是连续的整数,那么我们可以在 $O(n)$ 求出 $f(k)$。

我们假设 $x_i = i$,那么:

$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-x_j}{x_i-x_j}$$

可以转化为:

$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-j}{i-j}$$

分类讨论 $j < i$ 和 $j > i$ 的情况:

$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{j=1}^{i-1} \frac{k-j}{i-j} \prod\limits_{j=i+1}^n \frac{k-j}{i-j}$$

所以可得:

$$f(k) = \sum\limits_{i=1}^ny_i \frac{1}{(i-1)!} (k-1)(k-2)…(k-(i-1)) \frac{1}{-1*-2*…*(-(n-i))} (k-i)(k-(i+1))…(k-n)$$

化简得到:

$$f(k) = \sum\limits_{i=1}^n \frac{y_i}{(-1)^{n-i} *i! * (n-i)!} \prod\limits_{j=1}^{i-1}(k-j) \prod\limits_{j=i+1}^n(k-j)$$

• 注意,后面这部分不能变成 $\frac{\prod\limits_{j=1}^{n}(k-j)}{k-i}$,因为我们无法确定 $k-i \neq 0$。

对于 $\prod\limits_{j=1}^{i-1}(k-j) \prod\limits_{j=i+1}^n(k-j)$,我们只要预处理出来一个前缀积和后缀积即可。

所以对于每一项(每个 $i$),我们都可以 $O(1)$ 时间算出对应值,总复杂度 $O(n)$。

重心拉格朗日插值

用于解决动态加点的问题。

利用重心拉格朗日插值,每加入一个新的点,我们可以在 $O(n)$ 时间内算出新的多项式。

$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j}^n \frac{k-x_j}{x_i-x_j}$$

我们设 $g(k) = \prod\limits_{i=1}^n(k-x_i)$,则有:

$$f(k) = g(k)\sum\limits_{i=1}^n \frac{1}{k-x_i} \prod\limits_{i\neq j}^n \frac{y_i}{x_i-x_j}$$

设 $t_i = \frac{y_i}{\prod\limits_{j \neq i}^n (x_i-x_j)}$,则有:

$$f(k) = g(k) \sum\limits_{i=1}^n \frac{t_i}{k-x_i}$$

所以每次添加一个新的点 $(x_{n+1}, y_{n+1})$ 时:

  1. 重新计算一下所有的 $t_i = t_i * \frac{1}{x_i - x_{n+1}}, i \in [1,n]$。
  2. 计算 $t_{n+1}$。
  3. 计算 $g(k)$。

• 注意,在求 $f(k)$ 时需要先判断一下是否有 $k=x_i$,有的话直接返回 $y_i$。(这个特判仅需要在重心拉格朗日插值中进行)。

重心拉格朗日插值板子
int Q;
ll x[maxn], y[maxn], t[maxn];
int n;

ll qpow(ll a, ll b) {
    ll res = 1;

    while (b) {
        if (b & 1)
            res = res * a % mod;

        a = a * a % mod;
        b >>= 1;
    }

    return res;
}

ll solve(ll k) {
    for (int i = 1; i <= n; i++) {
        if (x[i] == k)  // 需要特判是否 k = x[i]
            return y[i];
    }

    ll ans = 0;
    ll g = 1;
    ll sum = 0;

    for (int i = 1; i <= n; i++) {
        g = g * (k - x[i] + mod) % mod;
        sum = (sum + t[i] * qpow((k - x[i] + mod) % mod, mod - 2) % mod) % mod;
    }

    ans = g * sum % mod;
    return ans;
}

int main() {
    cin >> Q;
    while (Q--) {
        int op; cin >> op;

        if (op == 1) {  // 添加 (xx,yy) 的点
            ll xx, yy;
            cin >> xx >> yy;
            n++;
            x[n] = xx, y[n] = yy;
            ll de = 1;
            for (int j = 1; j < n; j++) {
                de = de * (x[n] - x[j] + mod) % mod;  // 更新 t[n]
                t[j] = t[j] * qpow((x[j] - x[n] + mod) % mod, mod - 2) % mod;
            }
            de = qpow(de, mod - 2);
            t[n] = y[n] * de % mod;
        } else {
            ll k; cin >> k;  // 求 f(k)
            cout << solve(k) << endl;
        }
    }
}

常用模型

  1. $\sum\limits_{i=1}^k i^n = 1^n+2^n+…k^n$ 是一个多项式 $f(k)$,其中 $deg(f)=n+1$(意味着需要 $n+2$ 个点进行插值)。

例题

例1 CF622F The Sum of the k-th Powers

题意

给定 $k,n$,求:

$$\sum\limits_{i=1}^k i^n = 1^n+2^n+…k^n$$

其中,$0 \leq k \leq 10^9, 1 \leq n \leq 10^6$。

题解

上面说了这是一个多项式 $f(k)$,其中 $deg(f)=n+1$,意味着需要 $n+2$ 个点进行插值。

又因为 $x_i = i, y_i = \sum\limits_{j=1}^i j^n$,我们直接按照上述 $O(n)$ 的连续整数来插值即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+5;

ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

ll k,n;
ll fac[maxn], inv_fac[maxn], t[maxn];
ll y[maxn], pre[maxn], suf[maxn];
int main() {
    fastio;
    cin >> k >> n;
    n += 2;

    fac[0] = inv_fac[0] = 1;
    pre[0] = suf[n+1] = 1;
    for (ll i = 1; i <= n; i++) {
        fac[i] = fac[i-1] * i % mod;
        y[i] = (y[i-1] + qpow(i, n-2)) % mod;
        pre[i] = pre[i-1] * (k - i + mod) % mod;
    }
    for (ll i = n; i >= 1; i--) suf[i] = suf[i+1] * (k - i + mod) % mod;
    inv_fac[n] = qpow(fac[n], mod-2);
    for (ll i = n-1; i >= 1; i--) {
        inv_fac[i] = inv_fac[i+1] * (i+1) % mod;
    }

    ll ans = 0;
    for (ll i = 1; i <= n; i++) {
        int flag = ((n-i) & 1 ? -1 : 1);
        ll res = y[i] % mod * inv_fac[i-1] % mod * inv_fac[n-i] % mod * pre[i-1] % mod * suf[i+1] % mod;
        res = (res * flag + mod) % mod;
        ans = (ans + res) % mod;
    }
    cout << ans << endl;
}

例2 CF1731F. Function Sum

题意

如果我们有一个正整数数组 $a_1,a_2…,a_n$,定义 $ls(i)$ 为:$i$ 的左边比它小的数量,$gr(i)$ 为:$i$ 的右边比它大的数量。

如果对于index $i$,$ls(i) < gr(i)$,说明 $i$ 是good的。

定义 $f(a)$($a$ 是一个数组 $a_1,a_2…,a_n$)的值为:所有的 good $i$ 的对应 $a_i$ 的和。


现给定 $n, k$,求所有长度为 $n$,保证 $\forall i \in [1,n], a_i \in [1,k]$ 的所有可能的数组 $a$,$f(a)$ 的和。

其中,$n \in [1, 50], k \in [2, 998244353)$,答案对 $998244353$ 取模。

题解

典中典之:看见 “所有满足xx条件” 就想到贡献。

考虑每一个位置 $i$ 的每一个可能的值 $j$ 对答案会贡献多少?

我们设 $f(i,j)$ 为:如果 $a_i = j$,那么满足 $ls(i) < gr(i)$ 的数组数量。

则 $j$ 这个数的贡献就是 $j * f(i,j)$。

形式化的,我们最终要求的答案就是

$$\sum\limits_{i=1}^n \sum\limits_{j=1}^k j * f(i,j)$$


怎么求 $f(i,j)$?

我们枚举 $ls(i)$ 和 $gr(i)$ 的值,设 $k_1 = ls(i), k_2 = gr(i)$,有:

$$f(i,j) = \sum\limits_{k_1=0}^{i-1} \sum\limits_{k_2 = k_1+1}^{n-i} C_{i-1}^{k_1} (j-1)^{k_1}(k-j+1)^{i-1-k_1} C_{n-i}^{k_2}(k-j)^{k_2}(j)^{n-i-k_2}$$

简单解释一下:

$C_{i-1}^{k_1} (j-1)^{k_1}(k-j+1)^{i-1-k_1}$:其中 $C_{i-1}^{k_1}$ 代表从 $i$ 左边元素里选出 $k_1$ 个比它小的,而这 $k_1$ 个元素取的值只要 $<j$ 就可以了,所以有 $(j-1)^{k_1}$ 种,而剩下的 $(i-1-k_1)$ 个元素则只要满足 $\geq j$ 即可,所以有 $(k-j+1)^{i-1-k_1}$ 种。

$C_{n-i}^{k_2}(k-j)^{k_2}(j)^{n-i-k_2}$ 同理。


注意到 $n \leq 50$,所以枚举 $i$ 是没问题,但是 $k$ 非常大所以不能枚举。

但是,我们发现如果固定了 $i$,那么就只有一个变量 $j$ 了。

看一下 $\sum\limits_{j=1}^k j * f(i,j)$ 中,$j$ 的最高系数是 $1 + k_1 + (i-1-k_1) + k_2 + (n-i-k_2) = n$。

所以 $\sum\limits_{j=1}^k j * f(i,j)$ 是一个多项式,参考上面的常用模型,$\sum\limits_{j=1}^k a_j j^n$ 是一个 $deg = n+1$ 的多项式,需要 $(n+2)$ 个点来插值。

插值的点,我们就选 $j’ = 1…55$,$x_{j’} = j'$,而 $y_{j’} = \sum\limits_{j'=1}^j f(i,j’)$。插值结束后可以得到一个多项式 $h(x)$,直接求 $h(k)$ 的值即可。


如果 $k \leq 55$ 使得插值的点数量不够呢?那对于较小的 $k$ 我们直接暴力计算就可以了。

代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;

template<class T>
T qpow(T a, int b) {
    T res = 1;
    while (b) {
        if (b & 1) res *= a;
        a *= a;
        b >>= 1;
    }
    return res;
}
int norm(int x) {
    if (x < 0) {
        x += mod;
    }
    if (x >= mod) {
        x -= mod;
    }
    return x;
}
struct Z {
    int x;
    Z(int x = 0) : x(norm(x)) {}
    Z(ll x) : x(norm(x % mod)) {}
    int val() const {
        return x;
    }
    Z operator-() const {
        return Z(norm(mod - x));
    }
    Z inv() const {
        assert(x != 0);
        return qpow(*this, mod - 2);
    }
    Z &operator*=(const Z &rhs) {
        x = (ll)(x) * rhs.x % mod;
        return *this;
    }
    Z &operator+=(const Z &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    Z &operator-=(const Z &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    Z &operator/=(const Z &rhs) {
        return *this *= rhs.inv();
    }
    friend Z operator*(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res *= rhs;
        return res;
    }
    friend Z operator+(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res += rhs;
        return res;
    }
    friend Z operator-(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res -= rhs;
        return res;
    }
    friend Z operator/(const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res /= rhs;
        return res;
    }
    friend std::istream &operator>>(std::istream &is, Z &a) {
        ll v;
        is >> v;
        a = Z(v);
        return is;
    }
    friend std::ostream &operator<<(std::ostream &os, const Z &a) {
        return os << a.val();
    }
};

const int maxn = 1e4+5;
const int maxm = 1e5+5;

const int M = 1e4;
int n, k;
Z x[maxn], y[maxn];
Z solve(Z k) {
    Z ans = 0;
    for (int i = 1; i <= 55; i++) {
        Z nu = 1, de = 1;
        for (int j = 1; j <= 55; j++) {
            if (j == i) continue;
            nu = nu * (k - x[j]);
            de = de * (x[i] - x[j]);
        }
        de = de.inv();
        ans = ans + y[i] * nu * de;
    }
    return ans;
}

Z fac[maxm], inv_fac[maxm];
Z C(Z i, Z j) {
    return fac[i.val()] * inv_fac[j.val()] * inv_fac[(i-j).val()];
}

// brute force calculate f(i, j)
Z cal(Z i, Z j) {
    Z ans = 0;
    for (Z k1 = 0; k1.val() <= (i-1).val(); k1 = k1+1) {
        for (Z k2 = k1+1; k2.val() <= (n-i).val(); k2 = k2+1) {
            ans = ans + C(i-1, k1) * qpow(j-1, k1.val()) * qpow(k-j+1, (i-1-k1).val()) * C(n-i, k2) * qpow(k-j, k2.val()) * qpow(j, (n-i-k2).val());
        }
    }
    return ans;
}

int main() {
    cin >> n >> k;
    fac[0] = 1;
    for (int i = 1; i <= M; i++) fac[i] = fac[i-1] * i;
    inv_fac[M] = fac[M].inv();
    for (int i = M-1; i >= 0; i--) inv_fac[i] = inv_fac[i+1] * (i+1);

    if (k <= 55) {
        Z ans = 0;
        for (Z i = 1; i.val() <= n; i = i + 1) {
            for (Z j = 1; j.val() <= k; j = j + 1) {
                ans = ans + (j * cal(i, j));
            }
        }
        cout << ans.val() << endl;
    } else {
        Z ans = 0;
        for (Z i = 1; i.val() <= n; i = i + 1) {
            // 插值 f(i,j)
            for (int j = 1; j <= 55; j++) {
                x[j] = j;
                y[j] = y[j-1] + cal(i, j) * j;
            }
            ans += solve(k);
        }
        cout << ans << endl;
    }
}