拉格朗日插值
Contents
拉格朗日插值
给定 $n$ 个点,我们可以确定唯一的最高 degree 为 $(n-1)$ 的多项式。
设多项式为 $f(x)$,第 $i$ 个点的坐标为 $(x_i,y_i)$,那么这个多项式在 $k$ 处的取值为:
$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-x_j}{x_i-x_j}$$
时间复杂度:$O(n^2)$
拉格朗日插值板子
const int maxn = 2005;
int n;
ll k;
ll x[maxn], y[maxn];
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1)
res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
ll solve(ll k) {
ll ans = 0;
for (int i = 1; i <= n; i++) {
ll nu = 1, de = 1;
for (int j = 1; j <= n; j++) {
if (j == i) continue;
nu = nu * (k - x[j] + mod) % mod;
de = de * (x[i] - x[j] + mod) % mod;
}
de = qpow(de, mod-2);
ans = (ans + y[i] * nu % mod * de % mod) % mod;
}
return ans;
}
int main() {
cin >> n >> k;
for (int i = 1; i <= n; i++) cin >> x[i] >> y[i];
ll ans = solve(k);
cout << ans << endl;
}
$x$ 坐标为连续整数的拉格朗日插值
如果 $x$ 坐标是连续的整数,那么我们可以在 $O(n)$ 求出 $f(k)$。
我们假设 $x_i = i$,那么:
$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-x_j}{x_i-x_j}$$
可以转化为:
$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j} \frac{k-j}{i-j}$$
分类讨论 $j < i$ 和 $j > i$ 的情况:
$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{j=1}^{i-1} \frac{k-j}{i-j} \prod\limits_{j=i+1}^n \frac{k-j}{i-j}$$
所以可得:
$$f(k) = \sum\limits_{i=1}^ny_i \frac{1}{(i-1)!} (k-1)(k-2)…(k-(i-1)) \frac{1}{-1*-2*…*(-(n-i))} (k-i)(k-(i+1))…(k-n)$$
化简得到:
$$f(k) = \sum\limits_{i=1}^n \frac{y_i}{(-1)^{n-i} *i! * (n-i)!} \prod\limits_{j=1}^{i-1}(k-j) \prod\limits_{j=i+1}^n(k-j)$$
• 注意,后面这部分不能变成 $\frac{\prod\limits_{j=1}^{n}(k-j)}{k-i}$,因为我们无法确定 $k-i \neq 0$。
对于 $\prod\limits_{j=1}^{i-1}(k-j) \prod\limits_{j=i+1}^n(k-j)$,我们只要预处理出来一个前缀积和后缀积即可。
所以对于每一项(每个 $i$),我们都可以 $O(1)$ 时间算出对应值,总复杂度 $O(n)$。
重心拉格朗日插值
用于解决动态加点的问题。
利用重心拉格朗日插值,每加入一个新的点,我们可以在 $O(n)$ 时间内算出新的多项式。
$$f(k) = \sum\limits_{i=1}^ny_i \prod\limits_{i\neq j}^n \frac{k-x_j}{x_i-x_j}$$
我们设 $g(k) = \prod\limits_{i=1}^n(k-x_i)$,则有:
$$f(k) = g(k)\sum\limits_{i=1}^n \frac{1}{k-x_i} \prod\limits_{i\neq j}^n \frac{y_i}{x_i-x_j}$$
设 $t_i = \frac{y_i}{\prod\limits_{j \neq i}^n (x_i-x_j)}$,则有:
$$f(k) = g(k) \sum\limits_{i=1}^n \frac{t_i}{k-x_i}$$
所以每次添加一个新的点 $(x_{n+1}, y_{n+1})$ 时:
- 重新计算一下所有的 $t_i = t_i * \frac{1}{x_i - x_{n+1}}, i \in [1,n]$。
- 计算 $t_{n+1}$。
- 计算 $g(k)$。
• 注意,在求 $f(k)$ 时需要先判断一下是否有 $k=x_i$,有的话直接返回 $y_i$。(这个特判仅需要在重心拉格朗日插值中进行)。
重心拉格朗日插值板子
int Q;
ll x[maxn], y[maxn], t[maxn];
int n;
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1)
res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
ll solve(ll k) {
for (int i = 1; i <= n; i++) {
if (x[i] == k) // 需要特判是否 k = x[i]
return y[i];
}
ll ans = 0;
ll g = 1;
ll sum = 0;
for (int i = 1; i <= n; i++) {
g = g * (k - x[i] + mod) % mod;
sum = (sum + t[i] * qpow((k - x[i] + mod) % mod, mod - 2) % mod) % mod;
}
ans = g * sum % mod;
return ans;
}
int main() {
cin >> Q;
while (Q--) {
int op; cin >> op;
if (op == 1) { // 添加 (xx,yy) 的点
ll xx, yy;
cin >> xx >> yy;
n++;
x[n] = xx, y[n] = yy;
ll de = 1;
for (int j = 1; j < n; j++) {
de = de * (x[n] - x[j] + mod) % mod; // 更新 t[n]
t[j] = t[j] * qpow((x[j] - x[n] + mod) % mod, mod - 2) % mod;
}
de = qpow(de, mod - 2);
t[n] = y[n] * de % mod;
} else {
ll k; cin >> k; // 求 f(k)
cout << solve(k) << endl;
}
}
}
常用模型
- $\sum\limits_{i=1}^k i^n = 1^n+2^n+…k^n$ 是一个多项式 $f(k)$,其中 $deg(f)=n+1$(意味着需要 $n+2$ 个点进行插值)。
例题
例1 CF622F The Sum of the k-th Powers
题意
给定 $k,n$,求:
$$\sum\limits_{i=1}^k i^n = 1^n+2^n+…k^n$$
其中,$0 \leq k \leq 10^9, 1 \leq n \leq 10^6$。
题解
上面说了这是一个多项式 $f(k)$,其中 $deg(f)=n+1$,意味着需要 $n+2$ 个点进行插值。
又因为 $x_i = i, y_i = \sum\limits_{j=1}^i j^n$,我们直接按照上述 $O(n)$ 的连续整数来插值即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+5;
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1)
res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
ll k,n;
ll fac[maxn], inv_fac[maxn], t[maxn];
ll y[maxn], pre[maxn], suf[maxn];
int main() {
fastio;
cin >> k >> n;
n += 2;
fac[0] = inv_fac[0] = 1;
pre[0] = suf[n+1] = 1;
for (ll i = 1; i <= n; i++) {
fac[i] = fac[i-1] * i % mod;
y[i] = (y[i-1] + qpow(i, n-2)) % mod;
pre[i] = pre[i-1] * (k - i + mod) % mod;
}
for (ll i = n; i >= 1; i--) suf[i] = suf[i+1] * (k - i + mod) % mod;
inv_fac[n] = qpow(fac[n], mod-2);
for (ll i = n-1; i >= 1; i--) {
inv_fac[i] = inv_fac[i+1] * (i+1) % mod;
}
ll ans = 0;
for (ll i = 1; i <= n; i++) {
int flag = ((n-i) & 1 ? -1 : 1);
ll res = y[i] % mod * inv_fac[i-1] % mod * inv_fac[n-i] % mod * pre[i-1] % mod * suf[i+1] % mod;
res = (res * flag + mod) % mod;
ans = (ans + res) % mod;
}
cout << ans << endl;
}
例2 CF1731F. Function Sum
题意
如果我们有一个正整数数组 $a_1,a_2…,a_n$,定义 $ls(i)$ 为:$i$ 的左边比它小的数量,$gr(i)$ 为:$i$ 的右边比它大的数量。
如果对于index $i$,$ls(i) < gr(i)$,说明 $i$ 是good的。
定义 $f(a)$($a$ 是一个数组 $a_1,a_2…,a_n$)的值为:所有的 good $i$ 的对应 $a_i$ 的和。
现给定 $n, k$,求所有长度为 $n$,保证 $\forall i \in [1,n], a_i \in [1,k]$ 的所有可能的数组 $a$,$f(a)$ 的和。
其中,$n \in [1, 50], k \in [2, 998244353)$,答案对 $998244353$ 取模。
题解
典中典之:看见 “所有满足xx条件” 就想到贡献。
考虑每一个位置 $i$ 的每一个可能的值 $j$ 对答案会贡献多少?
我们设 $f(i,j)$ 为:如果 $a_i = j$,那么满足 $ls(i) < gr(i)$ 的数组数量。
则 $j$ 这个数的贡献就是 $j * f(i,j)$。
形式化的,我们最终要求的答案就是
$$\sum\limits_{i=1}^n \sum\limits_{j=1}^k j * f(i,j)$$
怎么求 $f(i,j)$?
我们枚举 $ls(i)$ 和 $gr(i)$ 的值,设 $k_1 = ls(i), k_2 = gr(i)$,有:
$$f(i,j) = \sum\limits_{k_1=0}^{i-1} \sum\limits_{k_2 = k_1+1}^{n-i} C_{i-1}^{k_1} (j-1)^{k_1}(k-j+1)^{i-1-k_1} C_{n-i}^{k_2}(k-j)^{k_2}(j)^{n-i-k_2}$$
简单解释一下:
$C_{i-1}^{k_1} (j-1)^{k_1}(k-j+1)^{i-1-k_1}$:其中 $C_{i-1}^{k_1}$ 代表从 $i$ 左边元素里选出 $k_1$ 个比它小的,而这 $k_1$ 个元素取的值只要 $<j$ 就可以了,所以有 $(j-1)^{k_1}$ 种,而剩下的 $(i-1-k_1)$ 个元素则只要满足 $\geq j$ 即可,所以有 $(k-j+1)^{i-1-k_1}$ 种。
$C_{n-i}^{k_2}(k-j)^{k_2}(j)^{n-i-k_2}$ 同理。
注意到 $n \leq 50$,所以枚举 $i$ 是没问题,但是 $k$ 非常大所以不能枚举。
但是,我们发现如果固定了 $i$,那么就只有一个变量 $j$ 了。
看一下 $\sum\limits_{j=1}^k j * f(i,j)$ 中,$j$ 的最高系数是 $1 + k_1 + (i-1-k_1) + k_2 + (n-i-k_2) = n$。
所以 $\sum\limits_{j=1}^k j * f(i,j)$ 是一个多项式,参考上面的常用模型,$\sum\limits_{j=1}^k a_j j^n$ 是一个 $deg = n+1$ 的多项式,需要 $(n+2)$ 个点来插值。
插值的点,我们就选 $j’ = 1…55$,$x_{j’} = j'$,而 $y_{j’} = \sum\limits_{j'=1}^j f(i,j’)$。插值结束后可以得到一个多项式 $h(x)$,直接求 $h(k)$ 的值即可。
如果 $k \leq 55$ 使得插值的点数量不够呢?那对于较小的 $k$ 我们直接暴力计算就可以了。
代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
template<class T>
T qpow(T a, int b) {
T res = 1;
while (b) {
if (b & 1) res *= a;
a *= a;
b >>= 1;
}
return res;
}
int norm(int x) {
if (x < 0) {
x += mod;
}
if (x >= mod) {
x -= mod;
}
return x;
}
struct Z {
int x;
Z(int x = 0) : x(norm(x)) {}
Z(ll x) : x(norm(x % mod)) {}
int val() const {
return x;
}
Z operator-() const {
return Z(norm(mod - x));
}
Z inv() const {
assert(x != 0);
return qpow(*this, mod - 2);
}
Z &operator*=(const Z &rhs) {
x = (ll)(x) * rhs.x % mod;
return *this;
}
Z &operator+=(const Z &rhs) {
x = norm(x + rhs.x);
return *this;
}
Z &operator-=(const Z &rhs) {
x = norm(x - rhs.x);
return *this;
}
Z &operator/=(const Z &rhs) {
return *this *= rhs.inv();
}
friend Z operator*(const Z &lhs, const Z &rhs) {
Z res = lhs;
res *= rhs;
return res;
}
friend Z operator+(const Z &lhs, const Z &rhs) {
Z res = lhs;
res += rhs;
return res;
}
friend Z operator-(const Z &lhs, const Z &rhs) {
Z res = lhs;
res -= rhs;
return res;
}
friend Z operator/(const Z &lhs, const Z &rhs) {
Z res = lhs;
res /= rhs;
return res;
}
friend std::istream &operator>>(std::istream &is, Z &a) {
ll v;
is >> v;
a = Z(v);
return is;
}
friend std::ostream &operator<<(std::ostream &os, const Z &a) {
return os << a.val();
}
};
const int maxn = 1e4+5;
const int maxm = 1e5+5;
const int M = 1e4;
int n, k;
Z x[maxn], y[maxn];
Z solve(Z k) {
Z ans = 0;
for (int i = 1; i <= 55; i++) {
Z nu = 1, de = 1;
for (int j = 1; j <= 55; j++) {
if (j == i) continue;
nu = nu * (k - x[j]);
de = de * (x[i] - x[j]);
}
de = de.inv();
ans = ans + y[i] * nu * de;
}
return ans;
}
Z fac[maxm], inv_fac[maxm];
Z C(Z i, Z j) {
return fac[i.val()] * inv_fac[j.val()] * inv_fac[(i-j).val()];
}
// brute force calculate f(i, j)
Z cal(Z i, Z j) {
Z ans = 0;
for (Z k1 = 0; k1.val() <= (i-1).val(); k1 = k1+1) {
for (Z k2 = k1+1; k2.val() <= (n-i).val(); k2 = k2+1) {
ans = ans + C(i-1, k1) * qpow(j-1, k1.val()) * qpow(k-j+1, (i-1-k1).val()) * C(n-i, k2) * qpow(k-j, k2.val()) * qpow(j, (n-i-k2).val());
}
}
return ans;
}
int main() {
cin >> n >> k;
fac[0] = 1;
for (int i = 1; i <= M; i++) fac[i] = fac[i-1] * i;
inv_fac[M] = fac[M].inv();
for (int i = M-1; i >= 0; i--) inv_fac[i] = inv_fac[i+1] * (i+1);
if (k <= 55) {
Z ans = 0;
for (Z i = 1; i.val() <= n; i = i + 1) {
for (Z j = 1; j.val() <= k; j = j + 1) {
ans = ans + (j * cal(i, j));
}
}
cout << ans.val() << endl;
} else {
Z ans = 0;
for (Z i = 1; i.val() <= n; i = i + 1) {
// 插值 f(i,j)
for (int j = 1; j <= 55; j++) {
x[j] = j;
y[j] = y[j-1] + cal(i, j) * j;
}
ans += solve(k);
}
cout << ans << endl;
}
}