线段树/树状数组/分块 例题
Contents
主要记录一些遇到的线段树/分块例题。
例1 CF438D
题意
给定 $N$ 个正整数和 $M$ 个询问,询问有 3 种:
$1 ~ l ~ r$:输出 $\sum\limits_{i=l}^r a_i$
$2 ~ l ~ r ~ x$:将 $a_l$ 到 $a_r$ 的所有数取 $\text{mod } x$
$3 ~ k ~ x$:将 $a_k = x$
其中,$1 \leq N,M \leq 10^5, 1 \leq a_i,x \leq 10^9$
题解
本题分块和线段树都可以做,我们这里用 线段树 来做。
主要是需要考虑 区间取模 怎么办?
回忆一下分块例题中的 区间开方,我们维护了一个额外的tag表示这个区间是否为 全0/全1,如果不是 全0/全1 就暴力开方。
取模操作同理,我们发现,如果 $a_i > x$,那么 $a_i \text{ mod } x \leq \frac{a_i}{2}$,所以对于每个 $a_i$,最多只会被 $\text{mod}$ $\log (a_i)$ 次!
所以,我们维护一个 区间最大值,取模时,检查一下 区间最大值是否大于 $x$:
- 如果大于 $x$,就继续往下递归。
- 如果小于 $x$,就直接返回。
base case 就是区间长度为 $1$ 时,直接对这个元素开方即可。
例2 CF558E
题意
给定一个长度为 $n$ 的string,仅包含小写字母。给 $q$ 个询问:
$l,r,k$:将string的 $[l,r]$ 进行sort,如果 $k=1$ 就升序,$k=0$ 降序。
输出所有询问结束后的string。
其中,$1\leq n \leq 10^5, 1 \leq q\leq 50000$
题解
线段树来处理。
首先,string只包含小写字母。所以每个 node
可以维护一个 cnt[26]
代表这个node里的每个字母出现的次数。
其次,对于排序,我们在每个 node
中维护一个标记 $k$ 来代表该区间是否排序好了。若 $k=0$ 代表降序,$k=1$ 代表升序,$k = -1$ 代表乱序。
最后,维护一个 $lazy$ 标记,我们会注意到对于一个node
而言,若 $lazy = 1$,那么这个 node
必然是排序好了的!(要么 $k=0$ ,要么 $k=1$)。
有了以上信息,我们就可以进行 sort
操作了!
sort
$[L,R]$ 的时候,步骤如下:
- 提取出 $[L,R]$ 内每个字母出现的次数。
- 求出 $[L,R]$ 与 $[l,mid]$(当前node 左child的范围)的区间交集 $[l_1,r_1]$
- 求出 $[L,R]$ 与 $[mid+1, r]$(当前node 右child的范围)的区间交集 $[l_2,r_2]$
- 用指针遍历 $a-z$(或者 $z-a$),根据升序/降序 将 字母出现的次数分别填充 到 左child和右child的
cnt[]
中。(注意,这里的填充是指:先填充进一个int* buf = new int[26];
的动态数组,然后将buf[]
作为参数,再往下传递,直到区间完全覆盖,再将buf[]
的内容复制进cnt[]
里)。
代码
using namespace std;
#include <bits/stdc++.h>
const int maxn = 1e5+5;
const int maxm = 2e5+10;
int n,q;
char arr[maxn];
struct Node {
int l,r,k,cnt[26];
bool lazy = 0;
} tr[4*maxn];
int tmp[26];
inline int len(int cur) { return tr[cur].r - tr[cur].l + 1; }
void push_up(int cur) {
int lc = cur<<1, rc = lc+1;
for (int i = 0; i < 26; i++) {
tr[cur].cnt[i] = tr[lc].cnt[i] + tr[rc].cnt[i];
}
if (tr[lc].k != -1 && tr[lc].k == tr[rc].k) tr[cur].k = tr[lc].k; // k = 1: increasing, k = 0: decreasing
else tr[cur].k = -1;
}
void put(int cur, int k) {
int lc = cur<<1, rc = lc+1;
memset(tr[lc].cnt, 0, sizeof(tr[lc].cnt));
memset(tr[rc].cnt, 0, sizeof(tr[rc].cnt));
memcpy(tmp, tr[cur].cnt, sizeof(tmp));
int lsz = len(lc), rsz = len(rc);
if (k) {
int p = 0;
while (p < 26 && lsz) {
int delta = min(lsz, tmp[p]);
tr[lc].cnt[p] += delta;
lsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p++;
}
while (p < 26 && rsz) {
int delta = min(rsz, tmp[p]);
tr[rc].cnt[p] += delta;
rsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p++;
}
} else {
int p = 25;
while (p >= 0 && lsz) {
int delta = min(lsz, tmp[p]);
tr[lc].cnt[p] += delta;
lsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p--;
}
while (p >= 0 && rsz) {
int delta = min(rsz, tmp[p]);
tr[rc].cnt[p] += delta;
rsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p--;
}
}
assert(lsz == 0);
assert(rsz == 0);
}
void push_down(int cur) {
if (!tr[cur].lazy) return;
int lc = cur<<1, rc = lc+1;
tr[cur].lazy = 0;
tr[lc].lazy = tr[rc].lazy = 1;
assert(tr[cur].k != -1);
int k = tr[cur].k;
tr[lc].k = k;
tr[rc].k = k;
put(cur,k);
}
void build(int cur, int L, int R) {
tr[cur].l = L, tr[cur].r = R;
if (L == R) {
memset(tr[cur].cnt, 0, sizeof(tr[cur].cnt));
tr[cur].cnt[arr[L]-'a'] = 1;
return;
}
int mid = (L+R) >> 1;
build(cur<<1, L, mid);
build(cur<<1|1, mid+1, R);
push_up(cur);
}
int ress[26]; // 每次query的结果会存到这里
void clear(int* buf) {
for (int i = 0; i < 26; i++) buf[i] = 0;
}
int inter(int l1, int r1, int l2, int r2) { // 求区间交集的长度
int l = max(l1,l2), r = min(r1,r2);
return max(0,r-l+1);
}
void update(int cur, int L, int R, int k, int* res) { // 注意参数里有个动态数组 res
int lc = cur<<1, rc = lc+1;
int l = tr[cur].l, r = tr[cur].r;
if (l >= L && r <= R) {
tr[cur].k = k;
tr[cur].lazy = 1;
for (int i = 0; i < 26; i++) tr[cur].cnt[i] = res[i]; // 区间完全覆盖,复制到 cnt 中。
clear(res); // 记得清空,之后可能还要用
return;
}
int mid = (l+r) >> 1;
int lsz = inter(l,mid,L,R), rsz = inter(mid+1,r,L,R);
int* buf = new int[26]; // 这里采用了动态数组
for (int i = 0; i < 26; i++) buf[i] = 0; //注意new出来的需要先清空一下,另外不能使用 memset(因为是指针)
if (k) {
int p = 0;
while (p < 26 && lsz) {
int delta = min(lsz, res[p]);
buf[p] += delta;
lsz -= delta;
res[p] -= delta;
if (!res[p]) p++;
}
if (L <= mid) {
update(lc, L, R, k, buf); // 传递 buf
}
while (p < 26 && rsz) {
int delta = min(rsz, res[p]);
buf[p] += delta;
rsz -= delta;
res[p] -= delta;
if (!res[p]) p++;
}
if (R > mid) {
update(rc, L, R, k, buf); // 传递 buf
}
} else {
int p = 25;
while (p >= 0 && lsz) {
int delta = min(lsz, res[p]);
buf[p] += delta;
lsz -= delta;
res[p] -= delta;
if (!res[p]) p--;
}
if (L <= mid) {
update(lc, L, R, k, buf); // 传递 buf
}
while (p >= 0 && rsz) {
int delta = min(rsz, res[p]);
buf[p] += delta;
rsz -= delta;
res[p] -= delta;
if (!res[p]) p--;
}
if (R > mid) {
update(rc, L, R, k, buf); // 传递 buf
}
}
delete[] buf;
assert(lsz == 0);
assert(rsz == 0);
push_up(cur);
}
void query(int cur, int L, int R) { // 求区间内每个字母出现的个数
int l = tr[cur].l, r = tr[cur].r;
if (l >= L && r <= R) {
for (int i = 0; i < 26; i++) ress[i] += tr[cur].cnt[i];
return;
}
int lc = cur<<1, rc = lc+1;
push_down(cur);
int mid = (l+r) >> 1;
if (L <= mid) query(lc, L, R);
if (R > mid) query(rc, L, R);
push_up(cur);
}
void printans() {
for (int i = 1; i <= n; i++) {
memset(ress, 0, sizeof(ress));
query(1,i,i);
for (int j = 0; j < 26; j++) {
if (ress[j]) {
printf("%c",(char)('a'+j));
ress[j]--;
break;
}
}
}
printf("\n");
}
int main() {
scanf("%d%d",&n,&q);
scanf("%s", arr+1);
build(1, 1, n);
while (q--) {
int l,r,k; scanf("%d%d%d",&l,&r,&k);
memset(ress, 0, sizeof(ress));
query(1,l,r);
update(1,l,r,k,ress);
}
printans();
}
例3 洛谷P1972 HH的项链
题意
给定一个长度为 $N$ 的数组 $a_1,a_2,…,a_n$,以及 $m$ 个询问,每次询问 $[L,R]$ 之间有多少个不同的数。
其中,$1 \leq n,m,a_i \leq 10^6$
题解
我们可以发现,如果我们固定了询问的右端点 $R$,那么无论 $L$ 为多少,在 $R$ 的左侧的所有重复数字中,仅保留最靠右的一个 copy 即可。
例如 $arr = 1,3,2,1,7,1$,那么我们在遍历到 $i = 4$ 时,我们仅需要保留最后一个 $1$ (也就是 index 为 $4$ 的数字)。
由上,在处理 区间内不同的数 时,一个常见的套路是:
- 离线处理所有询问,按右端点 $R$ 排序。
- 从左到右遍历数组,遍历到 $i$ 时,对于所有 $a_i$ 的 copy,仅保留最靠右的那一个(也就是 $i$),之前的所有 copy 全部删除。
- 回答所有 $[L, i]$ 的询问。
那么对于本题,第一步是离线处理询问,按右端点 $R$ 排序。
第二步是遍历数组,遍历过程中维护一个 int pos[]
数组,其中 pos[val]
代表:在 $arr[1…i]$ 中,值为 val
的数 最靠右的 index。
当我们遍历到 $i$ 时,令 int val = arr[i]
,将 pos[val]
处的数字 删掉(更新线段树),然后将 $i$ 处的数字 加入线段树,最后更新一下 pos[val] = i;
•本题中,删掉就是将 线段树的位置 $i$ 的值 减去$1$,加上就是将 线段树的位置 $i$ 的值 加上$1$。
然后回答所有 以 $i$ 为右端点的询问(求 $[L,i]$ 的和即可)。
注:如果询问 在线 怎么办?可以使用主席树!(对于
pr[i]
建主席树)具体做法参考:CF1422F Boring Queries
代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e6+5;
int n, m, arr[maxn];
struct node {
int sum;
} tr[maxn<<2];
void push_up(int cur) {
tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}
void update(int cur, int l, int r, int p, int x) {
if (l == r) {
tr[cur].sum += x; return;
}
int mid = (l+r) >> 1;
if (p <= mid) update(cur<<1, l, mid, p, x);
else update(cur<<1|1, mid+1, r, p, x);
push_up(cur);
}
int query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return tr[cur].sum;
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += query(cur<<1, l, mid, L, R);
if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
return res;
}
int ans[maxn];
int pos[maxn];
struct Query {
int id,l,r;
} q[maxn];
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> arr[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l,r;
cin >> l >> r;
q[i] = {i,l,r};
}
sort(q+1,q+m+1, [](auto a, auto b) {
return a.r < b.r; // 根据右端点离线
});
int ptr = 1;
for (int i = 1; i <= n; i++) {
int val = arr[i];
if (pos[val]) update(1, 1, n, pos[val], -1); // 删去 pos[val]
update(1, 1, n, i, 1); // 加上 i
pos[val] = i; // 更新 pos[val]
while (ptr <= m && q[ptr].r == i) { // 回答所有以 i 作为右端点的询问
int id = q[ptr].id, L = q[ptr].l, R = q[ptr].r;
ans[id] = query(1, 1, n, L, R);
ptr++;
}
if (ptr > m) break;
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}
例4 CF1000F One Occurrence
题意
给定一个长度为 $n$ 的数组,$m$ 个询问,每次询问一个区间 $[l,r]$,如果这个区间里存在只出现一次的数,输出这个数(如果有多个就输出任意一个),没有就输出 $0$。
其中 $n,m \leq 5 \times 10^5, 1 \leq a_i \leq 5 \times 10^5$。
题解
和例3类似的套路,也是离线处理询问(根据右端点sort),从左往右遍历,仅保留最靠右的复制。
问题在于怎么删除 和 加入数字?因为本题不再是求数量了,所以不能简单的加 $1$ 或者 减 $1$。
会发现,我们仅关心一个区间内是否存在 unique 的数字,对于一个询问 $[L,R]$ 内,我们只要看,是否存在 $i$ 使得 $arr[i]$ 的前一个复制 不在 $[L,R]$ 内。(也就是说,pos[val]
是否小于 $L$)
那么,我们用线段树维护一下 区间最小值 即可,其中 $[L,R]$ 的区间最小值就代表着所有$i \in [L,R]$ 中,pos[arr[i]]
的最小值。如果一个区间 $[L,R]$ 的最小值 $\geq L$,那么答案不存在,否则答案存在。
那么,删除位置 $i$ 就是将它这个位置上的值设为 $inf$。
加入位置 $i$ 的数,就是将它这个位置上的值设为 pos[arr[i]]
。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
struct node {
int pre = 1e9, idx = 0; // pre 代表 pos[val] 的值,idx代表这个最小值对应的 index
} tr[maxn<<2];
int pos[maxn], n, m, arr[maxn];
struct Query {
int id,l,r;
} q[maxn];
void push_up(int cur) {
int lpre = tr[cur<<1].pre, rpre = tr[cur<<1|1].pre;
if (lpre < rpre) tr[cur].pre = lpre, tr[cur].idx = tr[cur<<1].idx;
else tr[cur].pre = rpre, tr[cur].idx = tr[cur<<1|1].idx;
}
void update(int cur, int l, int r, int p, int x) {
if (l == r) {
tr[cur].pre = x, tr[cur].idx = l;
return;
}
int mid = (l+r) >> 1;
if (p <= mid) update(cur<<1, l, mid, p, x);
else update(cur<<1|1, mid+1, r, p, x);
push_up(cur);
}
pii query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return {tr[cur].pre, tr[cur].idx};
int mid = (l+r) >> 1;
pii r1 = {1e9, 0}, r2 = {1e9, 0};
if (L <= mid) r1 = query(cur<<1, l, mid, L, R);
if (R > mid) r2 = query(cur<<1|1, mid+1, r, L, R);
if (r1.first > r2.first) return r2;
else return r1;
}
int ans[maxn];
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> arr[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l,r; cin >> l >> r;
q[i] = {i,l,r};
}
sort(q+1, q+m+1, [](auto a, auto b) {
return a.r < b.r;
});
int ptr = 1;
for (int i = 1; i <= n; i++) {
int val = arr[i];
update(1, 1, n, i, pos[val]); // 在i处 加入数字 pos[val]
if (pos[val]) update(1, 1, n, pos[val], 1e9); // 在pos[val] 处删除数字(设为 inf)
pos[val] = i;
while (ptr <= m && q[ptr].r == i) {
int L = q[ptr].l, R = q[ptr].r;
pii res = query(1, 1, n, L, R);
int id = q[ptr].id;
if (res.first >= L) ans[id] = 0; // 如果这个区间内,所有 pos[val] 都 >= L
else ans[id] = arr[res.second]; // 否则,答案存在
ptr++;
}
if (ptr > m) break;
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}
例5 CF803G Periodic RMQ Problem
题意
给定一个长度为 $n$ 的数组,将这个数组复制为 $k$ 份并且拼接在一起。
然后回答 $Q$ 个询问,分两种询问:(所有询问都在拼接后的数组上进行)
$1 ~ l ~ r ~ x$:将 $[l,r]$ 中的所有元素改为 $x$。
$2 ~ l ~ r$:询问 $[l,r]$ 中的最小值。
其中,$1 \leq n \leq 10^5, 1 \leq k \leq 10^4, 1 \leq Q \leq 10^5, ~l,r \in [1,n]$
题解
注意到数组的总长度可以达到 $10^9$。但是询问只有 $Q=10^5$,我们于是想到了动态开点线段树。
但是注意到,这个数组是有初始值的,按理说应该把树建好,不能动态开点。
所以我们考虑 不建树,而是在开点的时候,把这个点的初始状态处理好(就是在没有任何修改的情况下,这个点的初始状态),然后再对这个点进行操作。
所以,对于一个点对应的区间 $[l,r]$,怎么处理初始状态?
注意到,由于数组是一个循环节,所以可以分以下三种情况:
- $[l,r]$ 的长度 $\geq n$,它的初始最小值就是数组 $[1,n]$ 的最小值。
- $[l,r]$ 的长度 $< n$,且属于同一个循环节,那么用 ST表 预处理一下 $[l \text{ mod } n,r\text{ mod } n]$ 的最小值即可。
- $[l,r]$ 的长度 $< n$,且不属于同一个循环节,那么它的初始最小值就是 $[l \text{ mod } n, n] \bigcup [1, r\text{ mod } n]$
开点的函数是代码里的 build()
,注意到只要我们访问到了 cur
,并且需要访问它的任意一个 child 时,需要把左右两个子树都开点,这保证了 pushup()
的正确性。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e7+5;
const int maxm = 1e5+5;
int n,k,Q;
int a[maxm];
int belong(int x) {
return (x-1) / n + 1;
}
struct Node {
int lc, rc, lazy, val;
} tr[maxn];
int id = 0;
int st[maxm][18];
int bin[maxm];
void build_st() {
bin[1] = 0, bin[2] = 1;
for (int i = 3; i < maxm; i++) bin[i] = bin[i>>1] + 1;
for (int i = 1; i <= n; i++) {
st[i][0] = a[i];
}
for (int j = 1; j < 18; j++) {
for (int i = 1; i + (1<<j) - 1 <= n; i++) {
st[i][j] = min(st[i][j-1], st[i+(1<<(j-1))][j-1]);
}
}
}
int ask_st(int l, int r) {
if (l > r) swap(l,r);
int j = bin[r-l+1];
return min(st[l][j], st[r-(1<<j)+1][j]);
}
void build(int& cur, int l, int r) {
if (cur) return;
cur = ++id;
if (r-l+1 >= n) {
tr[cur].val = ask_st(1, n);
return;
}
int bl = belong(l), br = belong(r);
l %= n, r %= n;
if (!l) l = n;
if (!r) r = n;
if (bl == br) {
tr[cur].val = ask_st(l,r);
return;
}
// bl != br
tr[cur].val = min(ask_st(l, n), ask_st(1, r));
}
void push_down(int cur) {
if (!tr[cur].lazy) return;
int lazy = tr[cur].lazy;
tr[cur].lazy = 0;
int lc = tr[cur].lc, rc = tr[cur].rc;
tr[lc].lazy = tr[lc].val = tr[rc].lazy = tr[rc].val = lazy;
}
void push_up(int cur) {
tr[cur].val = min(tr[tr[cur].lc].val, tr[tr[cur].rc].val);
}
void update(int cur, int l, int r, int L, int R, int x) {
if (l >= L && r <= R) {
tr[cur].lazy = tr[cur].val = x;
return;
}
int mid = (l+r) >> 1;
build(tr[cur].lc, l, mid);
build(tr[cur].rc, mid+1, r);
push_down(cur);
if (L <= mid) update(tr[cur].lc, l, mid, L, R, x);
if (R > mid) update(tr[cur].rc, mid+1, r, L, R, x);
push_up(cur);
}
int query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return tr[cur].val;
int mid = (l+r) >> 1;
build(tr[cur].lc, l, mid);
build(tr[cur].rc, mid+1, r);
push_down(cur);
int lres = 1e9, rres = 1e9;
if (L <= mid) lres = query(tr[cur].lc, l, mid, L, R);
if (R > mid) rres = query(tr[cur].rc, mid+1, r, L, R);
push_up(cur);
return min(lres, rres);
}
int rt = 0;
int main() {
fastio;
cin >> n >> k;
for (int i = 1; i <= n; i++) cin >> a[i];
build_st();
cin >> Q;
build(rt, 1, n*k);
for (int i = 1; i <= Q; i++) {
int op,l,r,x;
cin >> op;
if (op == 1) {
cin >> l >> r >> x;
update(1, 1, n*k, l, r, x);
} else {
cin >> l >> r;
int res = query(1, 1, n*k, l, r);
cout << res << "\n";
}
}
}
例6 CF gym103687F. Easy Fix
题意
给定一个 permutation $p_1, p_2 … p_n$。
定义 $A_i$ 为 $i$ 左边比 $p_i$ 小的元素数量,$B_i$ 为 $i$ 右边比 $p_i$ 大的元素数量,$C_i = \min(A_i, B_i)$。
给定 $m$ 个询问,每次询问 $L, R$,回答在交换 $p_L, p_R$ 的位置以后,$\sum\limits_{i=1}^n C_i$ 的值?
其中,$n \leq 10^5, m \leq 2 \times 10^5$。
题解
先讨论一下交换后有哪些元素受到影响?
很明显,$L, R$ 位置会受到影响,$i \in (L,R)$ 且 $p_i \in (p_L, p_R)$ 的会受到影响,其他的均无影响。
我们先讨论 $i \in (L,R)$ 且 $p_i \in (p_L, p_R)$ 这种情况:
Case1: 如果 $a_u < a_v$,那么 $A_i - 1, B_i + 1$。
Case1.1: 如果 $A_i \leq B_i$,那么 $C_i - 1$。
Case1.2: 如果 $A_i > B_i + 1$,那么 $C_i + 1$。
Case2: 如果 $a_u > a_v$,那么 $A_i + 1, B_i - 1$。
Case2.1: 如果 $A_i \geq B_i$,那么 $C_i - 1$。
Case2.2: 如果 $A_i < B_i - 1$,那么 $C_i + 1$。
然后再讨论一下 $L, R$ 受到的影响:
只要先把 $C_L, C_R$ 的值从询问中减掉,然后重新找一下换了之后的 $C_L, C_R$ 即可。
• 需要注意当 $p_L > p_R$ 时,换了之后 $R$ 位置上的 $p_L$ 计算 $A_i$ 时会将 $p_R$ 忽略掉,所以要加 $1$。
好的,现在该讨论一下怎么获得 $A_i, B_i, C_i$。
首先,获得 $A_i$ 可以直接用树状数组维护 $p_i$(作为index),树状数组上的值为 0 和 1。遍历一下就可以求出初始状态的 $A_i$ 了。
然后 $B_i = p_i - 1 - A_i$,很好理解,这样 $C_i$ 也有了。
再讨论一下怎么处理询问?
我们将 Case1, Case2 分开处理:如果我们开一个线段树/树状数组,index 为 $p_i$,值为 $-1, 0, 1$ 来记录 Case1.1, Case1.2 的情况,那么:
每次询问就是询问一个区间 $(L,R)$ 内,满足 $(p_L, p_R)$ 的值的一些东西的和。
似乎是 主席树?
确实,不过本题询问离线,Stupidcdd 教导我:
能用主席树做的,离线情况下就能用线段树/树状数组。
回忆一下主席树的作用:对于每个区间 $[L,R]$,维护一个权值线段树,原理其实是每个点维护一个线段树,$R$ 上的线段树减去 $L-1$ 位置的线段树而已。
那么询问离线的时候,我们只有一个树状数组,可以在遍历的过程中,获得前缀 $[1…i]$ 对应的树状数组,但是我们没有持久化,所以保存不下来。
因为树状数组的历史版本保存不下来,不妨从询问下手?
离线的一个常见套路:
对于询问的端点 $L,R$,每个端点开一个 vector 储存所有的询问,遍历到这个点时处理所有对应的询问,加到答案上。
这个题就是这样,我们想要知道 $(L,R)$ 内满足 $(p_L, p_R)$ 的值的一些东西的和,那么不妨在遍历到 $L$ 的时候询问一下 $(p_L, p_R)$,在遍历到 $R-1$ 的时候询问一下 $(p_L, p_R)$,然后一减,不就是答案了么?
struct Query_Info {
int pl, pr; // 询问 [pl, pr] 的部分
int f; // 符号
int idx; // 询问的 idx
int type; // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn]; // qinfo[i] 端点为 i 的需要询问
struct Single_Query {
int x; // 询问 < x 有多少个
int idx; // 询问的 idx
int diff; // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn]; // 单点询问 < x 有多少个
ll ans[maxn];
...
...
...
// (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
for (int j = 1; j <= m; j++) {
auto q = queries[j];
int u = q.first, v = q.second;
if (u != v) {
int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
int type = (p[u] < p[v] ? 1 : 2);
qinfo[u+1].push_back({mn, mx, -1, j, type}); // 由于处理时,还没加上,所以全部应该 + 1 (u -> u+1, v-1 -> v)
qinfo[v].push_back({mn, mx, 1, j, type});
ans[j] -= (C[u] + C[v]); // 先去掉 C[u] + C[v] 的影响,后面直接加上
single[u].push_back({p[v], j, 0});
single[v].push_back({p[u], j, p[u] > p[v]});
}
}
for (int i = 1; i <= n; i++) {
for (auto q : qinfo[i]) {
if (q.type == 1) {
ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
} else {
ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
}
}
for (auto q : single[i]) {
int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
ans[q.idx] += min(AA, BB);
}
if (A[i] <= B[i]) tr1.update(p[i], -1);
if (A[i] > B[i] + 1) tr1.update(p[i], 1);
if (A[i] >= B[i]) tr2.update(p[i], -1);
if (A[i] < B[i] - 1) tr2.update(p[i], 1);
tr3.update(p[i], 1);
}
在这个题中:
- $(L,R)$ 我们通过离线询问,答案相减来实现。
- $(p_L, p_R)$ 我们通过权值线段树/权值树状数组来实现。
- 树状数组上面维护了 Case1.1, 1.2 (或者 2.1, 2.2) 的情况。
总结一下思路:
离线的关键在于从左向右遍历的时候,获得每个前缀对应的树状数组。
在离线 $(L,R)$ 时,将询问的信息扔给每个端点的 vector,在遍历到的时候再处理,处理时直接对 ans[]
进行操作。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+55;
const int maxm = 2e5+5;
int n, p[maxn], m;
pii queries[maxm];
struct BIT {
int n, tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= n) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
int query(int p) {
int ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr, tr1, tr2, tr3;
struct Query_Info {
int pl, pr; // 询问 [pl, pr] 的部分
int f; // 符号
int idx; // 询问的 idx
int type; // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn]; // qinfo[i] 端点为 i 的需要询问
struct Single_Query {
int x; // 询问 < x 有多少个
int idx; // 询问的 idx
int diff; // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn]; // 单点询问 < x 有多少个
ll ans[maxn];
ll csum = 0;
int main() {
fastio;
cin >> n;
tr.n = tr1.n = tr2.n = tr3.n = n;
for (int i = 1; i <= n; i++) cin >> p[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l, r; cin >> l >> r;
if (l >= r) swap(l, r);
queries[i] = {l, r};
}
for (int i = 1; i <= n; i++) {
A[i] = tr.query(p[i] - 1); // 左边比它小
B[i] = p[i] - 1 - A[i]; // 右边比它小
C[i] = min(A[i], B[i]);
csum += C[i];
tr.update(p[i], 1);
}
for (int i = 1; i <= m; i++) ans[i] = csum;
// (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
for (int j = 1; j <= m; j++) {
auto q = queries[j];
int u = q.first, v = q.second;
if (u != v) {
int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
int type = (p[u] < p[v] ? 1 : 2);
qinfo[u+1].push_back({mn, mx, -1, j, type}); // 由于处理时,还没加上,所以全部应该 + 1 (u -> u+1, v-1 -> v)
qinfo[v].push_back({mn, mx, 1, j, type});
ans[j] -= (C[u] + C[v]); // 先去掉 C[u] + C[v] 的影响,后面直接加上
single[u].push_back({p[v], j, 0});
single[v].push_back({p[u], j, p[u] > p[v]});
}
}
for (int i = 1; i <= n; i++) {
for (auto q : qinfo[i]) {
if (q.type == 1) {
ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
} else {
ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
}
}
for (auto q : single[i]) {
int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
ans[q.idx] += min(AA, BB);
}
if (A[i] <= B[i]) tr1.update(p[i], -1);
if (A[i] > B[i] + 1) tr1.update(p[i], 1);
if (A[i] >= B[i]) tr2.update(p[i], -1);
if (A[i] < B[i] - 1) tr2.update(p[i], 1);
tr3.update(p[i], 1);
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}
例7 Atcoder ABC237G. Range Sort Query
题意
给定一个长度为 $n$ 的permutation $p_1,p_2…,p_n$,和一个正整数 $x$。
给定 $q$ 个询问,询问有两种:
$1 ~ L ~ R$:将 $p[L…R]$ 从小到大排序。
$2 ~ L ~ R$:将 $p[L…R]$ 从大到小排序。
所有询问结束后,回答 $x$ 所在的index。
其中,$n,q \leq 2 \times 10^5, x \in [1,n]$。
题解
经典套路题。一般看见这种每次询问然后排序的,一般来讲只有在整个数组中 distinct 的数非常少的时候才能做。
这个题注意到我们只关心 $x$,于是我们可以把剩下的数分为比 $x$ 小和比 $x$ 大的。
而除了 $x$ 以外所有数之间的相对顺序我们并不关心。
所以我们可以直接开两棵线段树,线段树里只储存 $0$ 或者 $1$,代表这个位置是否存在比 $x$ 大/小 的数。
而 sort 的时候,我们就相当于把两棵线段树的 $[L,R]$ 中,看一下 $<x$ 和 $>x$ 分别有多少个,然后清空整个线段,然后重新进行区间赋值即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5+5;
const int maxm = 1e5+5;
struct Tree_Node {
int sum = 0;
bool flag = 0;
int val; // 区间赋的值
};
struct Segment_Tree {
Tree_Node tr[maxn<<2];
void push_up(int cur) {
tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}
void push_down(int cur, int l, int r) {
if (!tr[cur].flag) return;
int val = tr[cur].val;
tr[cur].flag = 0;
int mid = (l+r) >> 1;
int lc = cur<<1, rc = cur<<1|1;
int llen = (mid - l + 1), rlen = (r - mid);
tr[lc].flag = tr[rc].flag = 1;
tr[lc].val = tr[rc].val = val;
tr[lc].sum = llen * val, tr[rc].sum = rlen * val;
}
// 区间赋值为 x
void update(int cur, int l, int r, int L, int R, int x) {
if (L > R) return;
if (L <= l && R >= r) {
tr[cur].flag = 1;
tr[cur].val = x;
tr[cur].sum = (r-l+1) * x;
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update(cur<<1, l, mid, L, R, x);
if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
// 区间求和
int query(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].sum;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += query(cur<<1, l, mid, L, R);
if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
return res;
}
} tr1, tr2; // 比它小,比它大
int n, q, x, p;
int a[maxn];
int main() {
cin >> n >> q >> x;
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (a[i] == x) p = i;
}
for (int i = 1; i <= n; i++) {
if (a[i] < x) tr1.update(1, 1, n, i, i, 1);
if (a[i] > x) tr2.update(1, 1, n, i, i, 1);
}
while (q--) {
int c, L, R; cin >> c >> L >> R;
int sum1 = tr1.query(1, 1, n, L, R);
int sum2 = tr2.query(1, 1, n, L, R);
tr1.update(1, 1, n, L, R, 0);
tr2.update(1, 1, n, L, R, 0);
if (c == 1) {
tr1.update(1, 1, n, L, L+sum1-1, 1);
tr2.update(1, 1, n, R-sum2+1, R, 1);
if (p >= L && p <= R) p = L + sum1;
} else {
tr2.update(1, 1, n, L, L+sum2-1, 1);
tr1.update(1, 1, n, R-sum1+1, R, 1);
if (p >= L && p <= R) p = L + sum2;
}
}
cout << p << endl;
}
例8 Eerie Subarrays
题意
给定一个 $1$ 到 $n$ 的permutation,求有多少个 subarray 使得其最左边的元素是这个subarray的中位数?
其中,$n \leq 2 \times 10^5$。
题解
假设我们正在考虑的数是 $a_i$,那么只用考虑起点从 $i$ 开始的subarray。
将所有 $< a_i$ 的数设为 $-1$,所有 $> a_i$ 的数设为 $1$,那么问题转化为:
求 $R > i$ 的数量,使得 $sum[i…R] = 0$ (等价于使得 $sum[i-1] = sum[R]$)。
线段树没法统计这个,但可以用分块来统计!
同时注意到给定的是一个 $1$ 到 $n$ 的permutation,所以我们从 $1$ 开始,初始状态下所有值均为 $1$,而 $1$ 所在的位置为 $0$,之后考虑下一个数的时候就是将 $1$ 所在的位置设为 $-1$,$2$ 所在的位置设为 $0$,以此类推。
这样每次更新只需要更改两个位置的数,而更改一个位置的数会影响后面的所有的 $sum$ 值。对于块内的,暴力更新,对于后面的所有块,给块加上一个标记即可。所以每次更新复杂度是 $O(\sqrt n)$。
最后如何统计答案?对于块内,暴力统计,对于一个整块,我们维护块内每个元素出现的次数即可。
如何维护?显然不能对每个块都开 $1$ 到 $n$ 的数组。但注意到每个块里面最多 $B = \sqrt n$ 个元素,并且每个元素被暴力减的次数 $\leq 2B$。所以我们可以先将每个块的所有元素初始化到 $[2B, 3B]$ 的范围内(通过更新标记来初始化),然后每个块里面只要维护一个长度为 $3B$ 的vector即可,总空间复杂度就是 $O(n)$ 了。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
int n, p[maxn], pos[maxn];
vector<int> b[maxn];
int f[maxn], B;
// [1,B] -> vector<int>, [B+1, 2B] -> vector<int>
int L[maxn], R[maxn];
int bel[maxn];
int bmax = 0;
vector<int> cnt[maxn]; // cnt[bi][i]: 代表 i 在 block(bi) 中的出现次数
// 将位置 x 的值减 1
void upd(int x) {
int bi = bel[x];
int l = L[bi];
int ori = b[bi][x-l];
b[bi][x - l] -= 1;
cnt[bi][ori]--;
cnt[bi][ori-1]++;
}
// 将位置 x 的值减 1, 说明所有 >= x 的值都要减 1
void updall(int x) {
int bi = bel[x];
int r = R[bi];
for (int j = x; j <= r; j++) {
upd(j);
}
for (int b = bi+1; b <= bmax; b++) {
f[b] -= 1;
}
}
// 回答 sum[x] 的值
int check(int x) {
int bi = bel[x], l = L[bi];
return f[bi] + b[bi][x-l];
}
// 询问有多少个 y >= x 使得 sum[y] = sum[x]
ll query(int x) {
int sumx = check(x);
int bi = bel[x], r = R[bi];
ll res = 0;
for (int j = x; j <= r; j++) {
if (check(j) == sumx) res++;
}
for (int b = bi+1; b <= bmax; b++) {
int flag = f[b];
// 有多少个 cnt[y] 使得 y + flag == sumx -> y == sumx - flag
if (sumx - flag >= 0 && sumx - flag < cnt[b].size()) res += cnt[b][sumx-flag];
}
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> p[i], pos[p[i]] = i;
fill(L, L+maxn, 1e9);
B = sqrt(n);
for (int i = 1; i <= n; i++) {
bel[i] = (i-1) / B + 1;
int bi = bel[i];
L[bi] = min(L[bi], i);
R[bi] = max(R[bi], i);
bmax = max(bmax, bi);
}
for (int i = 1; i <= bmax; i++) {
int j = 3*B + 10;
while (j--) cnt[i].push_back(0);
}
for (int i = 1; i <= n; i++) {
int bi = bel[i];
// 1: [1, B] -> flag = -2B, 2: [B+1, 2B] -> flag = 0, 3: flag = B
int flag = (bi - 3) * B;
f[bi] = flag;
b[bi].push_back(i - flag);
cnt[bi][i-flag]++;
}
ll ans = 0;
for (int i = 1; i <= n; i++) {
updall(pos[i]);
if (i > 1) updall(pos[i-1]);
ans += query(pos[i]);
}
cout << ans << endl;
}