介绍

整体二分是一种思想,用于同时二分多组询问,在题目满足以下条件时可以用:

  1. 有多组询问,每组询问可以通过二分解决。
  2. 询问离线。
  3. 询问之间互相独立。
  4. 修改之间互相独立。

一般整体二分是递归写法,用一个 solve(int l, int r, int L, int R) 来解决。

其中,$[l,r]$ 是二分答案的区间,$[L,R]$ 是询问的区间。

也就是说,当前在编号为 $[L,R]$ 的询问的答案一定在 $[l,r]$ 之间。

然后令 mid = (l+r) >> 1,判断 $[L,R]$ 的这些询问的答案与 mid 的关系,如果比它小,就将这些询问重新排列后放在左边,否则放在右边,然后继续递归。

时间复杂度

整体二分之所以比一个个二分的复杂度更加优秀,主要原因在于利用了每一层二分时,总操作数量是 $O(n)$ 级别的。对于每个询问分开二分,那么每个询问都需要 $O(n)$ 的操作次数,就会导致复杂度爆炸。

而为了保证每一层的总操作数量为 $O(n)$ 级别,这意味着在二分到 $[l,r]$ 这个区间时,必须 只考虑 在 $[l,r]$ 这个区间内的元素。要做到这一点,在向下递归之前,需要将 $[l,mid]$ 区间的贡献先全部计算完毕,并且清空任何数据结构中的贡献,以保证在下一层递归时只考虑 $[l,r]$。

注意事项

  1. 每次递归时,如果当前值域是 $[l,r]$,那么只考虑值在 $[l,r]$ 内的元素,剩下的元素就算能造成贡献也直接忽略不计,因为之前一定统计过了!
  2. 向下递归,保证数据结构一定是清空状态(利用操作的reverse来清空)。

例题

例1 洛谷P3527 [POI2011] MET-Meteors

题意

有 $n$ 个国家,$m$ 个太空站。第 $i$ 个太空站会属于第 $o_i$ 个国家,第 $i$ 个国家希望收集 $p_i$ 个陨石。

现在有 $k$ 场陨石雨,每次落雨会向 $[l_i,r_i]$ 的太空站提供 $a_i$ 个陨石。注意太空站是环形分布的,所以可能存在 $l_i > r_i$ 的情况。

对于每个国家,回答在第几次陨石雨后就能收集到足够的陨石,如果所有陨石雨结束后仍然无法满足,输出 NIE

其中,$n,m,k \leq 3 \times 10^5, a_i,p_i \in [1,10^9], l_i,r_i \in [1,m]$。

题解

首先明确本题的询问是什么,是:每一个国家在第几次陨石雨后能够满足条件,所以我们二分的答案区间应该是 “第几次“。

所以 solve(l,r,L,R) 中的 $l,r$ 就代表 $[L,R]$ 这些区间内的询问(代表的国家)在第 $x$ 次陨石雨后能够满足条件,其中 $x \in [l,r]$。

然后考虑令 mid = (l+r) >> 1,怎么能够验证某个国家的答案是 $< mid$ 还是 $\geq mid$ 呢?

给定 $mid$,我们可以先让 $[l,mid]$ 这一段的落雨降下来,然后对于每个国家,暴力检查这个国家旗下的所有太空站由于 $[l,mid]$ 这一段的落雨收集到的陨石数量,如果 $\geq$ 所需要的,就可以向左递归了,否则,将需要的数量减去这一段的贡献,然后向右递归。

那么这个降下落雨的过程,实际上是 区间修改,然后每个国家暴力检查旗下的每一个太空站,实际上是 单点查询。区间修改可以用差分数组 $O(1)$ 解决,而单点查询用差分数组的前缀和解决即可,所以用树状数组即可 $O(\log m)$ 解决。


时间复杂度:

一共有 $\log k$ 层,考虑到每个询问都一定会递归到最底层,这意味着每一层的询问数量和均为 $n$(也就是所有国家)。

而对于每一个询问的国家,都会暴力查询它旗下的所有太空站,所以每一层都会查询所有的太空站(总共 $m$ 个),而每个太空站都需要进行一次树状数组查询,所以每层的复杂度是 $O(m \log m)$。

总复杂度为 $O(m \log m \log k)$。


几个注意点:

  1. 因为向右递归时减去了 $[l,mid]$ 对应的贡献,为了不影响下一层的答案,在向下递归前,要将树状数组清空。
  2. 极端情况下,只有 $1$ 个国家,每次陨石雨降下 $10^9$ 个陨石到 $[1,m]$,这样这个国家收集到的数量为 $10^9 * m * m > 10^{18}$,会爆 long long,记得在检查陨石量是否超过所需要的时,只要超过就直接返回即可。
  3. 只要在最后加一个降下 inf 数量陨石的雨,就可以判断哪些国家收集不够了。
  4. 清空树状数组不要用 memset,而应该将所有的修改操作 reverse 回去。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5+5005;
const int maxm = 1e5+55;

int n, m, k, ans[maxn];
struct BIT {
    ll tr[maxn];
    inline int lowbit(int x) { return x & -x; }
    void update(int p, int val) {
        while (p <= maxn-5) {
            tr[p] += val;
            p += lowbit(p);
        }
    }
    // return sum[1...p]
    ll query(int p) {
        ll ans = 0;
        while (p > 0) {
            ans += tr[p];
            p -= lowbit(p);
        }
        return ans;
    }
} tr;

struct Node {
    int l, r; 
    ll a; 
} upd[maxn];
vector<int> adj[maxn];  // adj[i]: 国家i拥有的所有基站的编号 [1...m]

struct Query {
    ll need;
    int id;
} q[maxn], ql[maxn], qr[maxn];

void update(int l, int r, ll a) {
    if (l <= r) {
        tr.update(l, a);
        tr.update(r+1, -a);
    } else {
        tr.update(l, a);
        tr.update(m+1, -a);
        tr.update(1, a);
        tr.update(r+1, -a);
    }
}

// 查询 q[i]这个位置的国家此时已经收集到的量
ll query(int i) {
    ll sum = 0;
    for (int x : adj[q[i].id]) {
        sum += tr.query(x);
        if (sum >= q[i].need) return sum;  // 防止爆 long long
    }
    return sum;
}

// 二分的答案范围为 [l,r],意思是只考虑第 [l,r] 次落雨 (修改操作), 并且 [L,R] 的这些询问的答案必然在 [l,r] 之间
// 操作的区间范围为 [L,R],只考虑当前编号在 [L,R] 的询问, 也就是 q[L...R].id
void solve(int l, int r, int L, int R) {
    if (L > R) return;
    if (l == r) {
        for (int i = L; i <= R; i++) {
            ans[q[i].id] = l;
        }
        return;
    }
    int mid = (l+r) >> 1;
    // 先让 [l, mid] 的落雨下来
    for (int i = l; i <= mid; i++) {
        int x = upd[i].l, y = upd[i].r;
        update(x, y, upd[i].a);
    }

    int lcnt = 0, rcnt = 0;
    for (int i = L; i <= R; i++) {
        ll sum = query(i);
        if (q[i].need <= sum) {
            ql[++lcnt] = q[i];
        } else {
            q[i].need -= sum;
            qr[++rcnt] = q[i];
        }
    }

    for (int i = l; i <= mid; i++) {
        int x = upd[i].l, y = upd[i].r;
        update(x, y, -upd[i].a);  // 清空树状数组
    }

    // 有 lcnt 个在左边
    int head = L-1;
    for (int i = 1; i <= lcnt; i++) q[++head] = ql[i];
    for (int i = 1; i <= rcnt; i++) q[++head] = qr[i];
    solve(l, mid, L, L + lcnt - 1);
    solve(mid+1, r, L + lcnt, R);
}
int main() {
    cin >> n >> m;
    for (int i = 1; i <= m; i++) {
        int o; cin >> o;
        adj[o].push_back(i);
    }
    for (int i = 1; i <= n; i++) cin >> q[i].need;
    cin >> k;
    for (int i = 1; i <= k; i++) {
        int l, r, a; cin >> l >> r >> a;
        upd[i] = {l, r, a};
    }
    upd[++k] = {1, m, (int)(1e9+7)};
    solve(1, k, 1, n);
    for (int i = 1; i <= n; i++) {
        if (ans[i] == k) cout << "NIE\n";
        else cout << ans[i] << "\n";
    }
}

例2 洛谷P3834 区间第k小

题意

给定 $N$ 个整数 $a_1,a_2,…,a_n$,和 $m$ 个询问,每次询问 $[L,R]$ 之间的第 $k$ 小值。保证询问合法。

其中,$1 \leq n,m \leq 2 \times 10^5, |a_i| \leq 10^9$

题解

主席树模版题!整体二分写起来更加简单!

首先明确二分的是答案,那么先考虑:对于一个询问来说,给定一个 mid,想知道这个询问的答案是 $\leq mid$ 还是 $> mid$,怎么办?

注意到我们只关心哪些数字 $> mid$,哪些数字 $\leq mid$。所以 $\leq mid$ 的数字可以全部看作 $1$,而 $> mid$ 的全部看作 $0$。

然后对于这个询问 $[L,R]$,我们判断 $sum[L…R]$ 是否 $\geq k$ 即可,这个就是单点修改,区间求和,树状数组即可。

但这样有个问题,我们不能对于每个询问,都把整个数组处理一遍,变成 $0,1$ 吧?注意到值域是 $[l,r]$ 意味着我们只考虑这些元素,于是我们只把所有值 $\in [l,mid]$ 的 $a_i$ 变成 $1$ 即可。这样保证了整体二分的复杂度,每层仍然是 $O(n)$ 的。

• 当然,我们需要先离散化才能这么做,否则 $[l,mid]$ 会非常大。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+50;
const int maxm = 1e5+55;

struct BIT {
    ll tr[maxn];
    inline int lowbit(int x) { return x & -x; }
    void update(int p, int val) {
        while (p <= maxn-5) {
            tr[p] += val;
            p += lowbit(p);
        }
    }
    // return sum[1...p]
    ll query(int p) {
        ll ans = 0;
        while (p > 0) {
            ans += tr[p];
            p -= lowbit(p);
        }
        return ans;
    }
} tr;


int n, m;
int a[maxn];
struct Query {
    int l, r, k, id;
} q[maxn], ql[maxn], qr[maxn];
int ans[maxn];
map<int, int> mp;  // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn];

vector<int> pos[maxn];  // pos[x]: 值为x的所有index
void solve(int l, int r, int L, int R) {
    if (L > R) return;
    if (l == r) {
        for (int i = L; i <= R; i++) {
            ans[q[i].id] = l;
        }
        return;
    }
    int mid = (l+r) >> 1;
    // 还是先考虑 [l,mid]
    for (int i = l; i <= mid; i++) {
        for (int p : pos[i]) {
            tr.update(p, 1);
        }
    }

    int lcnt = 0, rcnt = 0;
    for (int i = L; i <= R; i++) {
        int sum = tr.query(q[i].r) - tr.query(q[i].l-1);
        if (sum >= q[i].k) {
            ql[++lcnt] = q[i];
        } else {
            q[i].k -= sum;
            qr[++rcnt] = q[i];
        }
    }

    int j = L - 1;
    for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
    for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];

    for (int i = l; i <= mid; i++) {
        for (int p : pos[i]) {
            tr.update(p, -1);
        }
    }

    solve(l, mid, L, L + lcnt - 1);
    solve(mid+1, r, L + lcnt, R);
}

int main() {
    fastio;
    cin >> n >> m;
    set<int> se;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        se.insert(a[i]);
    }
    int M = 0;
    for (auto x : se) {
        mp[x] = ++M;
        rev_mp[M] = x;
    }
    for (int i = 1; i <= n; i++) {
        a[i] = mp[a[i]];
        pos[a[i]].push_back(i);
    }
    // 被map到 [1,M] 之间了 
    for (int i = 1; i <= m; i++) {
        int l, r, k; cin >> l >> r >> k;
        q[i] = {l, r, k, i};
    }
    solve(1, M, 1, m);
    for (int i = 1; i <= m; i++) {
        cout << rev_mp[ans[i]] << "\n";
    }
}

例3 洛谷P2617 Dynamic Rankings

题意

给定一个长度为 $n$ 的序列 $a_1,a_2,…,a_n$,有两种询问,询问共 $m$ 个:

  1. $Q~l~r~k$:询问 $[l,r]$ 之间第 $k$ 小的数。
  2. $C~x~y$:将 $a_x$ 改为 $y$。

其中,$n,m \leq 10^5, a_i, y \in [0,10^9]$,保证每个询问合法。

题解

上一题的加强版,多了个修改操作。

对于修改操作,我们可以将它拆分成 删除一个数,再 加入一个数。并且由于事先知道了所有询问(离线),所以我们可以预处理出每一个修改操作在哪个位置删除了哪个数,加入了哪个数。

这样的话,删除一个数/加入一个数,就可以和普通的询问一样,调成只和值域有关的询问了。

于是,有三种询问:

  1. 普通的区间询问
  2. 在某个位置加入一个数
  3. 在某个位置删除一个数

这三种询问可以放在同一个询问序列里面进行处理。

solve(l, r, L, R) 时,同样只考虑值域只在 $[l,r]$ 内的数。

然后在决定哪些询问去左边/右边时,很明显对于询问 $2,3$,根据加入/删除的这个数的值是在 $[l,mid]$ 还是 $[mid+1, r]$ 来决定左右。

对于询问 $1$,注意到有些询问 $1$ 之前可能是有一些询问 $2,3$ 的,那么我们按顺序处理询问 $2,3$,只考虑所有值域在 $[l,mid]$ 的询问 $2,3$ 对树状数组进行update,然后处理到询问 $1$ 的时候就跟上一题一样。

• 为什么在我们把询问本身的顺序更改后,这样做仍然是正确的?

因为询问只是根据值域分开了,而对于同一个值域中的任意两个询问,它们的相对顺序是不改变的,所以这保证了询问 $2,3$ 和 询问 $1$ 之间的相对顺序正确。

• 对于一开始给定的序列,可以把它们看作:在位置 $1,2,…,n$ 分别加入 $a_1,a_2,…,a_n$ 这些数,然后也同样把它们当作询问一起处理了。

• 最后注意,本题在离散化的时候,要把所有修改的数 $y$ 也一起离散化,因为它也属于值域的一部分。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+50;
const int maxm = 1e5+55;

struct BIT {
    ll tr[maxn];
    inline int lowbit(int x) { return x & -x; }
    void update(int p, int val) {
        while (p <= maxn-5) {
            tr[p] += val;
            p += lowbit(p);
        }
    }
    // return sum[1...p]
    ll query(int p) {
        ll ans = 0;
        while (p > 0) {
            ans += tr[p];
            p -= lowbit(p);
        }
        return ans;
    }
} tr;


int n, m;
int a[maxn];
struct Query {
    int l, r, k, id, type;  
    // type = 1: query [l,r,k], id
    // type = 2: 在位置l 加入 a_k
    // type = 3: 在位置l 删掉 a_k
} q[maxn], ql[maxn], qr[maxn];
int ans[maxn];
map<int, int> mp;  // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn];

void solve(int l, int r, int L, int R) {
    if (L > R) return;
    if (l == r) {  // 注意到在递归到最底层的时候就不考虑 加入/删除 数了
        for (int i = L; i <= R; i++) {
            if (q[i].type == 1) 
                ans[q[i].id] = l;
        }
        return;
    }
    int mid = (l+r) >> 1;

    int lcnt = 0, rcnt = 0;
    for (int i = L; i <= R; i++) {
        if (q[i].type == 1) {
            int sum = tr.query(q[i].r) - tr.query(q[i].l-1);
            if (sum >= q[i].k) {
                ql[++lcnt] = q[i];
            } else {
                q[i].k -= sum;
                qr[++rcnt] = q[i];
            }
        } else {
            if (q[i].k >= l && q[i].k <= mid) {
                if (q[i].type == 2) tr.update(q[i].l, 1);
                if (q[i].type == 3) tr.update(q[i].l, -1);
            }
            if (q[i].k >= l && q[i].k <= mid) ql[++lcnt] = q[i];
            if (q[i].k > mid && q[i].k <= r) qr[++rcnt] = q[i];
        }
    }

    for (int i = L; i <= R; i++) {
        if (q[i].k >= l && q[i].k <= mid) {
            if (q[i].type == 2) tr.update(q[i].l, -1);
            if (q[i].type == 3) tr.update(q[i].l, 1);
        }
    }

    int j = L - 1;
    for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
    for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];

    solve(l, mid, L, L + lcnt - 1);
    solve(mid+1, r, L + lcnt, R);
}

int main() {
    fastio;
    cin >> n >> m;

    int id = 0;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        q[++id] = {i, -1, a[i], -1, 2};
    }

    for (int i = 1; i <= m; i++) {
        char c; cin >> c;
        if (c == 'Q') {
            int l, r, k; cin >> l >> r >> k;
            q[++id] = {l, r, k, i, 1};
        } else {
            int x, y; cin >> x >> y;
            q[++id] = {x, -1, a[x], -1, 3};
            q[++id] = {x, -1, y, -1, 2};
            a[x] = y;
        }
    }

    // 注意所有询问都结束后才开始离散化

    set<int> se;
    for (int i = 1; i <= id; i++) se.insert(q[i].k);
    int M = 0;
    for (auto x : se) {
        mp[x] = ++M;
        rev_mp[M] = x;
    }
    for (int i = 1; i <= n; i++) a[i] = mp[a[i]];
    for (int i = 1; i <= id; i++) {
        if (q[i].type != 1)
            q[i].k = mp[q[i].k];
    }

    memset(ans, -1, sizeof(ans));
    // 被map到 [1,M] 之间了 
    solve(1, M, 1, id);
    for (int i = 1; i <= id; i++) {
        if (ans[i] >= 0)
            cout << rev_mp[ans[i]] << "\n";
    }
}

例4 洛谷P1527 [国家集训队] 矩阵乘法

题意

给定一个 $n \times n$ 的矩阵,$m$ 次询问,每次询问一个子矩阵的第 $k$ 小的数。

其中,$n \leq 500, m \leq 6 \times 10^4, a_{i,j} \in [0,10^9]$。

题解

除了一维变成二维,看起来和 例2 一毛一样

事实上就是一样的,就连 solve(l, r, L, R) 都不需要更改,因为这些参数和问题的维度无关。

只不过在判断一个询问是在左边还是右边的时候,需要查询一个子矩阵的和是否 $\geq k$ 了。

也就是需要一个 区间查询和,单点修改 的数据结构,用二维树状数组即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 500+15;
const int maxm = 1e5+55;

struct BIT2 {
    int tr[maxn][maxn];
    inline int lowbit(int x) { return x&-x; }
    void update(int x, int y, int val) {
        int tmp = y;
        while (x <= maxn-5) {
            y = tmp;	
            while (y <= maxn-5) {
                tr[x][y] += val;
                y += lowbit(y);
            }
            x += lowbit(x);
        }
    }
    // 查询 [1...x][1...y]
    int query(int x, int y) {
        int tmp = y;
        ll ans = 0;
        while (x > 0) {
            y = tmp;
            while (y > 0) {
                ans += tr[x][y];
                y -= lowbit(y);
            }
            x -= lowbit(x);
        }
        return ans;
    }

    int query(int x1, int y1, int x2, int y2) {
        return query(x2, y2) - query(x1-1, y2) - query(x2, y1-1) + query(x1-1, y1-1);
    }
} tr;

int n, m;
int a[maxn][maxn];
struct Query {
    int x1, y1, x2, y2, k, id;
} q[maxm], ql[maxm], qr[maxm];
int ans[maxm];
map<int, int> mp;  // mp: val->rank, rev_mp: rank->val
int rev_mp[maxn * maxn];

vector<pii> pos[maxn * maxn];  // pos[x]: 值为x的所有index
void solve(int l, int r, int L, int R) {
    if (L > R) return;
    if (l == r) {
        for (int i = L; i <= R; i++) {
            ans[q[i].id] = l;
        }
        return;
    }
    int mid = (l+r) >> 1;
    // 还是先考虑 [l,mid]
    for (int i = l; i <= mid; i++) {
        for (auto [x, y] : pos[i]) {
            tr.update(x, y, 1);
        }
    }

    int lcnt = 0, rcnt = 0;
    for (int i = L; i <= R; i++) {
        auto [x1, y1, x2, y2, k, _] = q[i];
        int sum = tr.query(x1, y1, x2, y2);
        if (sum >= q[i].k) {
            ql[++lcnt] = q[i];
        } else {
            q[i].k -= sum;
            qr[++rcnt] = q[i];
        }
    }

    int j = L - 1;
    for (int i = 1; i <= lcnt; i++) q[++j] = ql[i];
    for (int i = 1; i <= rcnt; i++) q[++j] = qr[i];

    for (int i = l; i <= mid; i++) {
        for (auto [x, y] : pos[i]) {
            tr.update(x, y, -1);
        }
    }

    solve(l, mid, L, L + lcnt - 1);
    solve(mid+1, r, L + lcnt, R);
}

int main() {
    fastio;
    cin >> n >> m;
    set<int> se;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            cin >> a[i][j];
            se.insert(a[i][j]);
        }
    }
    int M = 0;
    for (auto x : se) {
        mp[x] = ++M;
        rev_mp[M] = x;
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            a[i][j] = mp[a[i][j]];
            pos[a[i][j]].push_back({i,j});
        }
    }
    // 被map到 [1,M] 之间了 
    for (int i = 1; i <= m; i++) {
        int x1, y1, x2, y2, k; cin >> x1 >> y1 >> x2 >> y2 >> k;
        q[i] = {x1, y1, x2, y2, k, i};
    }
    solve(1, M, 1, m);
    for (int i = 1; i <= m; i++) {
        cout << rev_mp[ans[i]] << "\n";
    }
}