这篇 blog 记录一个树状数组上二分的小技巧。

介绍

对于树状数组,我们知道它支持 $O(\log n)$ 内进行如下操作:

  1. 单点修改
  2. 前缀和查询

现在我们希望它用 $O(\log n)$ 支持第三种操作:

  1. 二分找到一个前缀和的值

形式化的,例如在上面 lower_bound(v),就希望找到一个最小的 $i$,使得 $sum[1…i] \geq v$。

思想

我们利用倍增的思想,类似于树上倍增找 LCA,我们知道,因为树状数组是以 $2^i$ 来记录前缀和的,我们也可以从高位开始枚举,如果当前的和加上这一段的和 $< v$ 那么就加上,并且更新 pos,否则看下一位。

// 找到最小的 p 使得 sum[1...p] >= v
int LOGN = 19;
int bit_lowerbound(int v) {
    int sum = 0;
    int pos = 0;

    for(int i = LOGN; i >= 0; i--) {
        if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] < v) {
            sum += tr[pos + (1 << i)];
            pos += (1 << i);
        }
    }

    return pos + 1; // +1 because 'pos' will have position of largest value < 'v'
}

// 找到最小的 p 使得 sum[1...p] > v
int bit_upperbound(int v) {
    int sum = 0;
    int pos = 0;

    for(int i = LOGN; i >= 0; i--) {
        if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] <= v) {
            sum += tr[pos + (1 << i)];
            pos += (1 << i);
        }
    }

    return pos + 1; // +1 because 'pos' will have position of largest value <= 'v'
}

由于树状数组的特性,pos 需要从 $0$ 开始。

线段树版本

其实线段树也可以 $O(\log n)$ 进行二分前缀和。

不过前缀和的方式不太一样,线段树的方式是从根开始 dfs,判断左边的和加上当前和是否 $< v$,如果是的话,加上左边的和,然后走到右边,否则就走左边。

这个思想有点类似于主席树的区间第 $k$ 大,只不过是线段树版本,更简单一些。

例题

例1 CF787E. Till I Collapse

题意

给定一个长度为 $n$ 的数组 $a_1,a_2,…,a_n$。

对于每一个 $k \in [1,n]$,问数组最少需要分成连续的几段,使得每一段里面 distinct 元素的数量 $\leq k$。

其中,$n \in [1, 10^5], a_i \in [1, n]$。

题解

先思考简化版的问题:

给定一个数组,$q$ 个询问,每次询问 $[L,R]$ 之间有多少个 distinct 的元素?

这个问题已经记录过了:HH的项链,大体思路就是离线,然后遍历询问的过程中只保留最靠左/右的那个,建立线段树/树状数组维护一个 $01$ 序列代表这个元素是否在当前范围内即可。


回到这个问题,对于每一个 $k$,可以知道我们从 $1$ 出发,每次向右走尽可能多的步数使得范围内 distinct 元素的数量 $\leq k$。能走到哪?可以考虑使用二分。

因为询问 $[L,R]$ 之间的 distinct 元素个数本质上就是 $sum[L…R] \leq k$,其实就是 $s_R - s_{L-1} \leq k$,也就是 $s_R \leq k + s_{L-1}$。

由于我们询问时固定了 $L$,所以 $k + s_{L-1}$ 是定值,那么 $R$ 的位置可以利用树状数组上 upper_bound 来 $O(\log n)$ 找到。


最后,需要考虑怎么离线询问?

注意到如果我们一边处理询问,一边获得下一个区间的左端点,那么区间的左端点是自然排序好的,不用手动去排序了。

实现的过程中,对于每一个左端点都开一个 vector 记录询问即可。

复杂度:$O(n \log n)$。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;

int n, a[maxn], nxt[maxn], pos[maxn];
inline int lowbit(int x) { return x & -x; }
int tr[maxn];
void update(int p, int val) {
    while (p <= n) {
        tr[p] += val;
        p += lowbit(p);
    }
}
// return sum[1...p]
int query(int p) {
    int ans = 0;
    while (p > 0) {
        ans += tr[p];
        p -= lowbit(p);
    }
    return ans;
}
int LOGN = 19;

// 找到最小的 p 使得 sum[1...p] >= v
int bit_lowerbound(int v) {
    int sum = 0;
    int pos = 0;

    for(int i = LOGN; i >= 0; i--) {
        if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] < v) {
            sum += tr[pos + (1 << i)];
            pos += (1 << i);
        }
    }

    return pos + 1; // +1 because 'pos' will have position of largest value < 'v'
}

// 找到最小的 p 使得 sum[1...p] > v
int bit_upperbound(int v) {
    int sum = 0;
    int pos = 0;

    for(int i = LOGN; i >= 0; i--) {
        if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] <= v) {
            sum += tr[pos + (1 << i)];
            pos += (1 << i);
        }
    }

    return pos + 1; // +1 because 'pos' will have position of largest value <= 'v'
}


vector<int> queries[maxn];  // query[i]: 以 i 为端点的所有 query 对应的 k值
int ans[maxn];
int main() {
    fastio;
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> a[i];
    for (int i = n; i >= 1; i--) {
        if (!pos[a[i]]) nxt[i] = n+1;
        else nxt[i] = pos[a[i]];
        pos[a[i]] = i;
    }

    // 先保留 [1,n] 中最靠左的
    for (int i = n; i >= 1; i--) {
        if (nxt[i] <= n) {
            update(nxt[i], -1);
        } 
        update(i, 1);
    }

    for (int k = 1; k <= n; k++) queries[1].push_back(k);

    for (int L = 1; L <= n; L++) {
        for (int k : queries[L]) {
            ans[k]++;
            int R = bit_upperbound(k + query(L-1)) - 1;
            queries[R+1].push_back(k);
        }
        update(L, -1);
        if (nxt[L] <= n) update(nxt[L], 1);
    }
    for (int i = 1; i <= n; i++) cout << ans[i] << " ";
    cout << "\n";
}

参考链接

  1. https://codeforces.com/blog/entry/61364