整体二分
Contents
介绍
整体二分是一种思想,用于同时二分多组询问,在题目满足以下条件时可以用:
- 有多组询问,每组询问可以通过二分解决。
- 询问离线。
- 询问之间互相独立。
- 修改之间互相独立。
一般整体二分是递归写法,用一个 solve(int l, int r, int L, int R)
来解决。
其中,$[l,r]$ 是二分答案的区间,$[L,R]$ 是询问的区间。
也就是说,当前在编号为 $[L,R]$ 的询问的答案一定在 $[l,r]$ 之间。
然后令 mid = (l+r) >> 1
,判断 $[L,R]$ 的这些询问的答案与 mid
的关系,如果比它小,就将这些询问重新排列后放在左边,否则放在右边,然后继续递归。
时间复杂度
整体二分之所以比一个个二分的复杂度更加优秀,主要原因在于利用了每一层二分时,总操作数量是 $O(n)$ 级别的。对于每个询问分开二分,那么每个询问都需要 $O(n)$ 的操作次数,就会导致复杂度爆炸。
而为了保证每一层的总操作数量为 $O(n)$ 级别,这意味着在二分到 $[l,r]$ 这个区间时,必须 只考虑 在 $[l,r]$ 这个区间内的元素。要做到这一点,在向下递归之前,需要将 $[l,mid]$ 区间的贡献先全部计算完毕,并且清空任何数据结构中的贡献,以保证在下一层递归时只考虑 $[l,r]$。
注意事项
- 每次递归时,如果当前值域是 $[l,r]$,那么只考虑值在 $[l,r]$ 内的元素,剩下的元素就算能造成贡献也直接忽略不计,因为之前一定统计过了!
- 向下递归前,保证数据结构一定是清空状态(利用操作的reverse来清空)。
例题
例1 洛谷P3527 [POI2011] MET-Meteors
题意
有 $n$ 个国家,$m$ 个太空站。第 $i$ 个太空站会属于第 $o_i$ 个国家,第 $i$ 个国家希望收集 $p_i$ 个陨石。
现在有 $k$ 场陨石雨,每次落雨会向 $[l_i,r_i]$ 的太空站提供 $a_i$ 个陨石。注意太空站是环形分布的,所以可能存在 $l_i > r_i$ 的情况。
对于每个国家,回答在第几次陨石雨后就能收集到足够的陨石,如果所有陨石雨结束后仍然无法满足,输出 NIE
。
其中,$n,m,k \leq 3 \times 10^5, a_i,p_i \in [1,10^9], l_i,r_i \in [1,m]$。
题解
首先明确本题的询问是什么,是:每一个国家在第几次陨石雨后能够满足条件,所以我们二分的答案区间应该是 “第几次“。
所以 solve(l,r,L,R)
中的 $l,r$ 就代表 $[L,R]$ 这些区间内的询问(代表的国家)在第 $x$ 次陨石雨后能够满足条件,其中 $x \in [l,r]$。
然后考虑令 mid = (l+r) >> 1
,怎么能够验证某个国家的答案是 $< mid$ 还是 $\geq mid$ 呢?
给定 $mid$,我们可以先让 $[l,mid]$ 这一段的落雨降下来,然后对于每个国家,暴力检查这个国家旗下的所有太空站由于 $[l,mid]$ 这一段的落雨收集到的陨石数量,如果 $\geq$ 所需要的,就可以向左递归了,否则,将需要的数量减去这一段的贡献,然后向右递归。
那么这个降下落雨的过程,实际上是 区间修改,然后每个国家暴力检查旗下的每一个太空站,实际上是 单点查询。区间修改可以用差分数组 $O(1)$ 解决,而单点查询用差分数组的前缀和解决即可,所以用树状数组即可 $O(\log m)$ 解决。
时间复杂度:
一共有 $\log k$ 层,考虑到每个询问都一定会递归到最底层,这意味着每一层的询问数量和均为 $n$(也就是所有国家)。
而对于每一个询问的国家,都会暴力查询它旗下的所有太空站,所以每一层都会查询所有的太空站(总共 $m$ 个),而每个太空站都需要进行一次树状数组查询,所以每层的复杂度是 $O(m \log m)$。
总复杂度为 $O(m \log m \log k)$。
几个注意点:
- 因为向右递归时减去了 $[l,mid]$ 对应的贡献,为了不影响下一层的答案,在向下递归前,要将树状数组清空。
- 极端情况下,只有 $1$ 个国家,每次陨石雨降下 $10^9$ 个陨石到 $[1,m]$,这样这个国家收集到的数量为 $10^9 * m * m > 10^{18}$,会爆 long long,记得在检查陨石量是否超过所需要的时,只要超过就直接返回即可。
- 只要在最后加一个降下
inf
数量陨石的雨,就可以判断哪些国家收集不够了。 - 清空树状数组不要用
memset
,而应该将所有的修改操作 reverse 回去。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5+5005;
const int maxm = 1e5+55;
int n, m, k, ans[maxn];
struct BIT {
ll tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= maxn-5) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
ll query(int p) {
ll ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr;
struct Node {
int l, r;
ll a;
} upd[maxn];
vector<int> adj[maxn]; // adj[i]: 国家i拥有的所有基站的编号 [1...m]
struct Query {
ll need;
int id;
} q[maxn], ql[maxn], qr[maxn];
void update(int l, int r, ll a) {
if (l <= r) {
tr.update(l, a);
tr.update(r+1, -a);
} else {
tr.update(l, a);
tr.update(m+1, -a);
tr.update(1, a);
tr.update(r+1, -a);
}
}
// 查询 q[i]这个位置的国家此时已经收集到的量
ll query(int i) {
ll sum = 0;
for (int x : adj[q[i].id]) {
sum += tr.query(x);
if (sum >= q[i].need) return sum; // 防止爆 long long
}
return sum;
}
// 二分的答案范围为 [l,r],意思是只考虑第 [l,r] 次落雨 (修改操作), 并且 [L,R] 的这些询问的答案必然在 [l,r] 之间
// 操作的区间范围为 [L,R],只考虑当前编号在 [L,R] 的询问, 也就是 q[L...R].id
void solve(int l, int r, int L, int R) {
if (L > R) return;
if (l == r) {
for (int i = L; i <= R; i++) {
ans[q[i].id] = l;
}
return;
}
int mid = (l+r) >> 1;
// 先让 [l, mid] 的落雨下来
for (int i = l; i <= mid; i++) {
int x = upd[i].l, y = upd[i].r;
update(x, y, upd[i].a);
}
int lcnt = 0, rcnt = 0;
for (int i = L; i <= R; i++) {
ll sum = query(i);
if (q[i].need <= sum) {
ql[++lcnt] = q[i];
} else {
q[i].need -= sum;
qr[++rcnt] = q[i];
}
}
for (int i = l; i <= mid; i++) {
int x = upd[i].l, y = upd[i].r;
update(x, y, -upd[i].a); // 清空树状数组
}
// 有 lcnt 个在左边
int head = L-1;
for (int i = 1; i <= lcnt; i++) q[++head] = ql[i];
for (int i = 1; i <= rcnt; i++) q[++head] = qr[i];
solve(l, mid, L, L + lcnt - 1);
solve(mid+1, r, L + lcnt, R);
}
int main() {
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int o; cin >> o;
adj[o].push_back(i);
}
for (int i = 1; i <= n; i++) cin >> q[i].need;
cin >> k;
for (int i = 1; i <= k; i++) {
int l, r, a; cin >> l >> r >> a;
upd[i] = {l, r, a};
}
upd[++k] = {1, m, (int)(1e9+7)};
solve(1, k, 1, n);
for (int i = 1; i <= n; i++) {
if (ans[i] == k) cout << "NIE\n";
else cout << ans[i] << "\n";
}
}
例2 洛谷P3834 区间第k小
题意
给定 $N$ 个整数 $a_1,a_2,…,a_n$,和 $m$ 个询问,每次询问 $[L,R]$ 之间的第 $k$ 小值。保证询问合法。
其中,$1 \leq n,m \leq 2 \times 10^5, |a_i| \leq 10^9$
题解
主席树模版题!整体二分写起来更加简单!
首先明确二分的是答案,那么先考虑:对于一个询问来说,给定一个 mid
,想知道这个询问的答案是 $\leq mid$ 还是 $> mid$,怎么办?
注意到我们只关心哪些数字 $> mid$,哪些数字 $\leq mid$。所以 $\leq mid$ 的数字可以全部看作 $1$,而 $> mid$ 的全部看作 $0$。
然后对于这个询问 $[L,R]$,我们判断 $sum[L…R]$ 是否 $\geq k$ 即可,这个就是单点修改,区间求和,树状数组即可。
但这样有个问题,我们不能对于每个询问,都把整个数组处理一遍,变成 $0,1$ 吧?注意到值域是 $[l,r]$ 意味着我们只考虑这些元素,于是我们只把所有值 $\in [l,mid]$ 的 $a_i$ 变成 $1$ 即可。这样保证了整体二分的复杂度,每层仍然是 $O(n)$ 的。
• 当然,我们需要先离散化才能这么做,否则 $[l,mid]$ 会非常大。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+50;
const int maxm = 1e5+55;
struct BIT {
ll tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= maxn-5) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
ll query(int p) {
ll ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr;
int n, m;
int a[maxn];
struct Query {
int l, r, k, id;
} q[maxn], ql[maxn], qr[maxn];
int ans[maxn];
map<int, int> mp; // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn];
vector<int> pos[maxn]; // pos[x]: 值为x的所有index
void solve(int l, int r, int L, int R) {
if (L > R) return;
if (l == r) {
for (int i = L; i <= R; i++) {
ans[q[i].id] = l;
}
return;
}
int mid = (l+r) >> 1;
// 还是先考虑 [l,mid]
for (int i = l; i <= mid; i++) {
for (int p : pos[i]) {
tr.update(p, 1);
}
}
int lcnt = 0, rcnt = 0;
for (int i = L; i <= R; i++) {
int sum = tr.query(q[i].r) - tr.query(q[i].l-1);
if (sum >= q[i].k) {
ql[++lcnt] = q[i];
} else {
q[i].k -= sum;
qr[++rcnt] = q[i];
}
}
int j = L - 1;
for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];
for (int i = l; i <= mid; i++) {
for (int p : pos[i]) {
tr.update(p, -1);
}
}
solve(l, mid, L, L + lcnt - 1);
solve(mid+1, r, L + lcnt, R);
}
int main() {
fastio;
cin >> n >> m;
set<int> se;
for (int i = 1; i <= n; i++) {
cin >> a[i];
se.insert(a[i]);
}
int M = 0;
for (auto x : se) {
mp[x] = ++M;
rev_mp[M] = x;
}
for (int i = 1; i <= n; i++) {
a[i] = mp[a[i]];
pos[a[i]].push_back(i);
}
// 被map到 [1,M] 之间了
for (int i = 1; i <= m; i++) {
int l, r, k; cin >> l >> r >> k;
q[i] = {l, r, k, i};
}
solve(1, M, 1, m);
for (int i = 1; i <= m; i++) {
cout << rev_mp[ans[i]] << "\n";
}
}
例3 洛谷P2617 Dynamic Rankings
题意
给定一个长度为 $n$ 的序列 $a_1,a_2,…,a_n$,有两种询问,询问共 $m$ 个:
- $Q~l~r~k$:询问 $[l,r]$ 之间第 $k$ 小的数。
- $C~x~y$:将 $a_x$ 改为 $y$。
其中,$n,m \leq 10^5, a_i, y \in [0,10^9]$,保证每个询问合法。
题解
上一题的加强版,多了个修改操作。
对于修改操作,我们可以将它拆分成 删除一个数,再 加入一个数。并且由于事先知道了所有询问(离线),所以我们可以预处理出每一个修改操作在哪个位置删除了哪个数,加入了哪个数。
这样的话,删除一个数/加入一个数,就可以和普通的询问一样,调成只和值域有关的询问了。
于是,有三种询问:
- 普通的区间询问
- 在某个位置加入一个数
- 在某个位置删除一个数
这三种询问可以放在同一个询问序列里面进行处理。
在 solve(l, r, L, R)
时,同样只考虑值域只在 $[l,r]$ 内的数。
然后在决定哪些询问去左边/右边时,很明显对于询问 $2,3$,根据加入/删除的这个数的值是在 $[l,mid]$ 还是 $[mid+1, r]$ 来决定左右。
对于询问 $1$,注意到有些询问 $1$ 之前可能是有一些询问 $2,3$ 的,那么我们按顺序处理询问 $2,3$,只考虑所有值域在 $[l,mid]$ 的询问 $2,3$ 对树状数组进行update,然后处理到询问 $1$ 的时候就跟上一题一样。
• 为什么在我们把询问本身的顺序更改后,这样做仍然是正确的?
因为询问只是根据值域分开了,而对于同一个值域中的任意两个询问,它们的相对顺序是不改变的,所以这保证了询问 $2,3$ 和 询问 $1$ 之间的相对顺序正确。
• 对于一开始给定的序列,可以把它们看作:在位置 $1,2,…,n$ 分别加入 $a_1,a_2,…,a_n$ 这些数,然后也同样把它们当作询问一起处理了。
• 最后注意,本题在离散化的时候,要把所有修改的数 $y$ 也一起离散化,因为它也属于值域的一部分。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+50;
const int maxm = 1e5+55;
struct BIT {
ll tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= maxn-5) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
ll query(int p) {
ll ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr;
int n, m;
int a[maxn];
struct Query {
int l, r, k, id, type;
// type = 1: query [l,r,k], id
// type = 2: 在位置l 加入 a_k
// type = 3: 在位置l 删掉 a_k
} q[maxn], ql[maxn], qr[maxn];
int ans[maxn];
map<int, int> mp; // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn];
void solve(int l, int r, int L, int R) {
if (L > R) return;
if (l == r) { // 注意到在递归到最底层的时候就不考虑 加入/删除 数了
for (int i = L; i <= R; i++) {
if (q[i].type == 1)
ans[q[i].id] = l;
}
return;
}
int mid = (l+r) >> 1;
int lcnt = 0, rcnt = 0;
for (int i = L; i <= R; i++) {
if (q[i].type == 1) {
int sum = tr.query(q[i].r) - tr.query(q[i].l-1);
if (sum >= q[i].k) {
ql[++lcnt] = q[i];
} else {
q[i].k -= sum;
qr[++rcnt] = q[i];
}
} else {
if (q[i].k >= l && q[i].k <= mid) {
if (q[i].type == 2) tr.update(q[i].l, 1);
if (q[i].type == 3) tr.update(q[i].l, -1);
}
if (q[i].k >= l && q[i].k <= mid) ql[++lcnt] = q[i];
if (q[i].k > mid && q[i].k <= r) qr[++rcnt] = q[i];
}
}
for (int i = L; i <= R; i++) {
if (q[i].k >= l && q[i].k <= mid) {
if (q[i].type == 2) tr.update(q[i].l, -1);
if (q[i].type == 3) tr.update(q[i].l, 1);
}
}
int j = L - 1;
for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];
solve(l, mid, L, L + lcnt - 1);
solve(mid+1, r, L + lcnt, R);
}
int main() {
fastio;
cin >> n >> m;
int id = 0;
for (int i = 1; i <= n; i++) {
cin >> a[i];
q[++id] = {i, -1, a[i], -1, 2};
}
for (int i = 1; i <= m; i++) {
char c; cin >> c;
if (c == 'Q') {
int l, r, k; cin >> l >> r >> k;
q[++id] = {l, r, k, i, 1};
} else {
int x, y; cin >> x >> y;
q[++id] = {x, -1, a[x], -1, 3};
q[++id] = {x, -1, y, -1, 2};
a[x] = y;
}
}
// 注意所有询问都结束后才开始离散化
set<int> se;
for (int i = 1; i <= id; i++) se.insert(q[i].k);
int M = 0;
for (auto x : se) {
mp[x] = ++M;
rev_mp[M] = x;
}
for (int i = 1; i <= n; i++) a[i] = mp[a[i]];
for (int i = 1; i <= id; i++) {
if (q[i].type != 1)
q[i].k = mp[q[i].k];
}
memset(ans, -1, sizeof(ans));
// 被map到 [1,M] 之间了
solve(1, M, 1, id);
for (int i = 1; i <= id; i++) {
if (ans[i] >= 0)
cout << rev_mp[ans[i]] << "\n";
}
}
例4 洛谷P1527 [国家集训队] 矩阵乘法
题意
给定一个 $n \times n$ 的矩阵,$m$ 次询问,每次询问一个子矩阵的第 $k$ 小的数。
其中,$n \leq 500, m \leq 6 \times 10^4, a_{i,j} \in [0,10^9]$。
题解
除了一维变成二维,看起来和 例2 一毛一样
事实上就是一样的,就连 solve(l, r, L, R)
都不需要更改,因为这些参数和问题的维度无关。
只不过在判断一个询问是在左边还是右边的时候,需要查询一个子矩阵的和是否 $\geq k$ 了。
也就是需要一个 区间查询和,单点修改 的数据结构,用二维树状数组即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 500+15;
const int maxm = 1e5+55;
struct BIT2 {
int tr[maxn][maxn];
inline int lowbit(int x) { return x&-x; }
void update(int x, int y, int val) {
int tmp = y;
while (x <= maxn-5) {
y = tmp;
while (y <= maxn-5) {
tr[x][y] += val;
y += lowbit(y);
}
x += lowbit(x);
}
}
// 查询 [1...x][1...y]
int query(int x, int y) {
int tmp = y;
ll ans = 0;
while (x > 0) {
y = tmp;
while (y > 0) {
ans += tr[x][y];
y -= lowbit(y);
}
x -= lowbit(x);
}
return ans;
}
int query(int x1, int y1, int x2, int y2) {
return query(x2, y2) - query(x1-1, y2) - query(x2, y1-1) + query(x1-1, y1-1);
}
} tr;
int n, m;
int a[maxn][maxn];
struct Query {
int x1, y1, x2, y2, k, id;
} q[maxm], ql[maxm], qr[maxm];
int ans[maxm];
map<int, int> mp; // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn * maxn];
vector<pii> pos[maxn * maxn]; // pos[x]: 值为x的所有index
void solve(int l, int r, int L, int R) {
if (L > R) return;
if (l == r) {
for (int i = L; i <= R; i++) {
ans[q[i].id] = l;
}
return;
}
int mid = (l+r) >> 1;
// 还是先考虑 [l,mid]
for (int i = l; i <= mid; i++) {
for (auto [x, y] : pos[i]) {
tr.update(x, y, 1);
}
}
int lcnt = 0, rcnt = 0;
for (int i = L; i <= R; i++) {
auto [x1, y1, x2, y2, k, _] = q[i];
int sum = tr.query(x1, y1, x2, y2);
if (sum >= q[i].k) {
ql[++lcnt] = q[i];
} else {
q[i].k -= sum;
qr[++rcnt] = q[i];
}
}
int j = L - 1;
for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];
for (int i = l; i <= mid; i++) {
for (auto [x, y] : pos[i]) {
tr.update(x, y, -1);
}
}
solve(l, mid, L, L + lcnt - 1);
solve(mid+1, r, L + lcnt, R);
}
int main() {
fastio;
cin >> n >> m;
set<int> se;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
cin >> a[i][j];
se.insert(a[i][j]);
}
}
int M = 0;
for (auto x : se) {
mp[x] = ++M;
rev_mp[M] = x;
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
a[i][j] = mp[a[i][j]];
pos[a[i][j]].push_back({i,j});
}
}
// 被map到 [1,M] 之间了
for (int i = 1; i <= m; i++) {
int x1, y1, x2, y2, k; cin >> x1 >> y1 >> x2 >> y2 >> k;
q[i] = {x1, y1, x2, y2, k, i};
}
solve(1, M, 1, m);
for (int i = 1; i <= m; i++) {
cout << rev_mp[ans[i]] << "\n";
}
}