树状数组二分
Contents
这篇 blog 记录一个树状数组上二分的小技巧。
介绍
对于树状数组,我们知道它支持 $O(\log n)$ 内进行如下操作:
- 单点修改
- 前缀和查询
现在我们希望它用 $O(\log n)$ 支持第三种操作:
- 二分找到一个前缀和的值
形式化的,例如在上面 lower_bound(v)
,就希望找到一个最小的 $i$,使得 $sum[1…i] \geq v$。
思想
我们利用倍增的思想,类似于树上倍增找 LCA,我们知道,因为树状数组是以 $2^i$ 来记录前缀和的,我们也可以从高位开始枚举,如果当前的和加上这一段的和 $< v$ 那么就加上,并且更新 pos
,否则看下一位。
// 找到最小的 p 使得 sum[1...p] >= v
int LOGN = 19;
int bit_lowerbound(int v) {
int sum = 0;
int pos = 0;
for(int i = LOGN; i >= 0; i--) {
if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] < v) {
sum += tr[pos + (1 << i)];
pos += (1 << i);
}
}
return pos + 1; // +1 because 'pos' will have position of largest value < 'v'
}
// 找到最小的 p 使得 sum[1...p] > v
int bit_upperbound(int v) {
int sum = 0;
int pos = 0;
for(int i = LOGN; i >= 0; i--) {
if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] <= v) {
sum += tr[pos + (1 << i)];
pos += (1 << i);
}
}
return pos + 1; // +1 because 'pos' will have position of largest value <= 'v'
}
由于树状数组的特性,pos
需要从 $0$ 开始。
线段树版本
其实线段树也可以 $O(\log n)$ 进行二分前缀和。
不过前缀和的方式不太一样,线段树的方式是从根开始 dfs,判断左边的和加上当前和是否 $< v$,如果是的话,加上左边的和,然后走到右边,否则就走左边。
这个思想有点类似于主席树的区间第 $k$ 大,只不过是线段树版本,更简单一些。
例题
例1 CF787E. Till I Collapse
题意
给定一个长度为 $n$ 的数组 $a_1,a_2,…,a_n$。
对于每一个 $k \in [1,n]$,问数组最少需要分成连续的几段,使得每一段里面 distinct 元素的数量 $\leq k$。
其中,$n \in [1, 10^5], a_i \in [1, n]$。
题解
先思考简化版的问题:
给定一个数组,$q$ 个询问,每次询问 $[L,R]$ 之间有多少个 distinct 的元素?
这个问题已经记录过了:HH的项链,大体思路就是离线,然后遍历询问的过程中只保留最靠左/右的那个,建立线段树/树状数组维护一个 $01$ 序列代表这个元素是否在当前范围内即可。
回到这个问题,对于每一个 $k$,可以知道我们从 $1$ 出发,每次向右走尽可能多的步数使得范围内 distinct 元素的数量 $\leq k$。能走到哪?可以考虑使用二分。
因为询问 $[L,R]$ 之间的 distinct 元素个数本质上就是 $sum[L…R] \leq k$,其实就是 $s_R - s_{L-1} \leq k$,也就是 $s_R \leq k + s_{L-1}$。
由于我们询问时固定了 $L$,所以 $k + s_{L-1}$ 是定值,那么 $R$ 的位置可以利用树状数组上 upper_bound 来 $O(\log n)$ 找到。
最后,需要考虑怎么离线询问?
注意到如果我们一边处理询问,一边获得下一个区间的左端点,那么区间的左端点是自然排序好的,不用手动去排序了。
实现的过程中,对于每一个左端点都开一个 vector
记录询问即可。
复杂度:$O(n \log n)$。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
int n, a[maxn], nxt[maxn], pos[maxn];
inline int lowbit(int x) { return x & -x; }
int tr[maxn];
void update(int p, int val) {
while (p <= n) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
int query(int p) {
int ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
int LOGN = 19;
// 找到最小的 p 使得 sum[1...p] >= v
int bit_lowerbound(int v) {
int sum = 0;
int pos = 0;
for(int i = LOGN; i >= 0; i--) {
if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] < v) {
sum += tr[pos + (1 << i)];
pos += (1 << i);
}
}
return pos + 1; // +1 because 'pos' will have position of largest value < 'v'
}
// 找到最小的 p 使得 sum[1...p] > v
int bit_upperbound(int v) {
int sum = 0;
int pos = 0;
for(int i = LOGN; i >= 0; i--) {
if (pos + (1 << i) <= n && sum + tr[pos + (1 << i)] <= v) {
sum += tr[pos + (1 << i)];
pos += (1 << i);
}
}
return pos + 1; // +1 because 'pos' will have position of largest value <= 'v'
}
vector<int> queries[maxn]; // query[i]: 以 i 为端点的所有 query 对应的 k值
int ans[maxn];
int main() {
fastio;
cin >> n;
for (int i = 1; i <= n; i++) cin >> a[i];
for (int i = n; i >= 1; i--) {
if (!pos[a[i]]) nxt[i] = n+1;
else nxt[i] = pos[a[i]];
pos[a[i]] = i;
}
// 先保留 [1,n] 中最靠左的
for (int i = n; i >= 1; i--) {
if (nxt[i] <= n) {
update(nxt[i], -1);
}
update(i, 1);
}
for (int k = 1; k <= n; k++) queries[1].push_back(k);
for (int L = 1; L <= n; L++) {
for (int k : queries[L]) {
ans[k]++;
int R = bit_upperbound(k + query(L-1)) - 1;
queries[R+1].push_back(k);
}
update(L, -1);
if (nxt[L] <= n) update(nxt[L], 1);
}
for (int i = 1; i <= n; i++) cout << ans[i] << " ";
cout << "\n";
}