主要记录一些遇到的线段树/分块例题。

例1 CF438D

题意

给定 N 个正整数和 M 个询问,询问有 3 种:

1 l r:输出 i=lrai

2 l r x:将 alar 的所有数取 mod x

3 k x:将 ak=x

其中,1N,M105,1ai,x109

题解

本题分块和线段树都可以做,我们这里用 线段树 来做。

主要是需要考虑 区间取模 怎么办?

回忆一下分块例题中的 区间开方,我们维护了一个额外的tag表示这个区间是否为 全0/全1,如果不是 全0/全1 就暴力开方。

取模操作同理,我们发现,如果 ai>x,那么 ai mod xai2,所以对于每个 ai,最多只会被 mod log(ai) 次!

所以,我们维护一个 区间最大值,取模时,检查一下 区间最大值是否大于 x

  1. 如果大于 x,就继续往下递归。
  2. 如果小于 x,就直接返回。

base case 就是区间长度为 1 时,直接对这个元素开方即可。

例2 CF558E

题意

给定一个长度为 n 的string,仅包含小写字母。给 q 个询问:

l,r,k:将string的 [l,r] 进行sort,如果 k=1 就升序,k=0 降序。

输出所有询问结束后的string。

其中,1n105,1q50000

题解

线段树来处理。

首先,string只包含小写字母。所以每个 node 可以维护一个 cnt[26] 代表这个node里的每个字母出现的次数。

其次,对于排序,我们在每个 node 中维护一个标记 k 来代表该区间是否排序好了。若 k=0 代表降序,k=1 代表升序,k=1 代表乱序。

最后,维护一个 lazy 标记,我们会注意到对于一个node而言,若 lazy=1,那么这个 node 必然是排序好了的!(要么 k=0 ,要么 k=1)。

有了以上信息,我们就可以进行 sort 操作了!


sort [L,R] 的时候,步骤如下:

  1. 提取出 [L,R] 内每个字母出现的次数。
  2. 求出 [L,R][l,mid](当前node 左child的范围)的区间交集 [l1,r1]
  3. 求出 [L,R][mid+1,r](当前node 右child的范围)的区间交集 [l2,r2]
  4. 用指针遍历 az(或者 za),根据升序/降序 将 字母出现的次数分别填充左child和右childcnt[] 中。(注意,这里的填充是指:先填充进一个 int* buf = new int[26]; 的动态数组,然后buf[] 作为参数,再往下传递,直到区间完全覆盖,再将 buf[] 的内容复制进 cnt[] 里)。
代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
using namespace std;
#include <bits/stdc++.h>
const int maxn = 1e5+5;
const int maxm = 2e5+10;
int n,q;
char arr[maxn];
struct Node {
int l,r,k,cnt[26];
bool lazy = 0;
} tr[4*maxn];
int tmp[26];
inline int len(int cur) { return tr[cur].r - tr[cur].l + 1; }
void push_up(int cur) {
int lc = cur<<1, rc = lc+1;
for (int i = 0; i < 26; i++) {
tr[cur].cnt[i] = tr[lc].cnt[i] + tr[rc].cnt[i];
}
if (tr[lc].k != -1 && tr[lc].k == tr[rc].k) tr[cur].k = tr[lc].k; // k = 1: increasing, k = 0: decreasing
else tr[cur].k = -1;
}
void put(int cur, int k) {
int lc = cur<<1, rc = lc+1;
memset(tr[lc].cnt, 0, sizeof(tr[lc].cnt));
memset(tr[rc].cnt, 0, sizeof(tr[rc].cnt));
memcpy(tmp, tr[cur].cnt, sizeof(tmp));
int lsz = len(lc), rsz = len(rc);
if (k) {
int p = 0;
while (p < 26 && lsz) {
int delta = min(lsz, tmp[p]);
tr[lc].cnt[p] += delta;
lsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p++;
}
while (p < 26 && rsz) {
int delta = min(rsz, tmp[p]);
tr[rc].cnt[p] += delta;
rsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p++;
}
} else {
int p = 25;
while (p >= 0 && lsz) {
int delta = min(lsz, tmp[p]);
tr[lc].cnt[p] += delta;
lsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p--;
}
while (p >= 0 && rsz) {
int delta = min(rsz, tmp[p]);
tr[rc].cnt[p] += delta;
rsz -= delta;
tmp[p] -= delta;
if (!tmp[p]) p--;
}
}
assert(lsz == 0);
assert(rsz == 0);
}
void push_down(int cur) {
if (!tr[cur].lazy) return;
int lc = cur<<1, rc = lc+1;
tr[cur].lazy = 0;
tr[lc].lazy = tr[rc].lazy = 1;
assert(tr[cur].k != -1);
int k = tr[cur].k;
tr[lc].k = k;
tr[rc].k = k;
put(cur,k);
}
void build(int cur, int L, int R) {
tr[cur].l = L, tr[cur].r = R;
if (L == R) {
memset(tr[cur].cnt, 0, sizeof(tr[cur].cnt));
tr[cur].cnt[arr[L]-'a'] = 1;
return;
}
int mid = (L+R) >> 1;
build(cur<<1, L, mid);
build(cur<<1|1, mid+1, R);
push_up(cur);
}
int ress[26]; // 每次query的结果会存到这里
void clear(int* buf) {
for (int i = 0; i < 26; i++) buf[i] = 0;
}
int inter(int l1, int r1, int l2, int r2) { // 求区间交集的长度
int l = max(l1,l2), r = min(r1,r2);
return max(0,r-l+1);
}
void update(int cur, int L, int R, int k, int* res) { // 注意参数里有个动态数组 res
int lc = cur<<1, rc = lc+1;
int l = tr[cur].l, r = tr[cur].r;
if (l >= L && r <= R) {
tr[cur].k = k;
tr[cur].lazy = 1;
for (int i = 0; i < 26; i++) tr[cur].cnt[i] = res[i]; // 区间完全覆盖,复制到 cnt 中。
clear(res); // 记得清空,之后可能还要用
return;
}
int mid = (l+r) >> 1;
int lsz = inter(l,mid,L,R), rsz = inter(mid+1,r,L,R);
int* buf = new int[26]; // 这里采用了动态数组
for (int i = 0; i < 26; i++) buf[i] = 0; //注意new出来的需要先清空一下,另外不能使用 memset(因为是指针)
if (k) {
int p = 0;
while (p < 26 && lsz) {
int delta = min(lsz, res[p]);
buf[p] += delta;
lsz -= delta;
res[p] -= delta;
if (!res[p]) p++;
}
if (L <= mid) {
update(lc, L, R, k, buf); // 传递 buf
}
while (p < 26 && rsz) {
int delta = min(rsz, res[p]);
buf[p] += delta;
rsz -= delta;
res[p] -= delta;
if (!res[p]) p++;
}
if (R > mid) {
update(rc, L, R, k, buf); // 传递 buf
}
} else {
int p = 25;
while (p >= 0 && lsz) {
int delta = min(lsz, res[p]);
buf[p] += delta;
lsz -= delta;
res[p] -= delta;
if (!res[p]) p--;
}
if (L <= mid) {
update(lc, L, R, k, buf); // 传递 buf
}
while (p >= 0 && rsz) {
int delta = min(rsz, res[p]);
buf[p] += delta;
rsz -= delta;
res[p] -= delta;
if (!res[p]) p--;
}
if (R > mid) {
update(rc, L, R, k, buf); // 传递 buf
}
}
delete[] buf;
assert(lsz == 0);
assert(rsz == 0);
push_up(cur);
}
void query(int cur, int L, int R) { // 求区间内每个字母出现的个数
int l = tr[cur].l, r = tr[cur].r;
if (l >= L && r <= R) {
for (int i = 0; i < 26; i++) ress[i] += tr[cur].cnt[i];
return;
}
int lc = cur<<1, rc = lc+1;
push_down(cur);
int mid = (l+r) >> 1;
if (L <= mid) query(lc, L, R);
if (R > mid) query(rc, L, R);
push_up(cur);
}
void printans() {
for (int i = 1; i <= n; i++) {
memset(ress, 0, sizeof(ress));
query(1,i,i);
for (int j = 0; j < 26; j++) {
if (ress[j]) {
printf("%c",(char)('a'+j));
ress[j]--;
break;
}
}
}
printf("\n");
}
int main() {
scanf("%d%d",&n,&q);
scanf("%s", arr+1);
build(1, 1, n);
while (q--) {
int l,r,k; scanf("%d%d%d",&l,&r,&k);
memset(ress, 0, sizeof(ress));
query(1,l,r);
update(1,l,r,k,ress);
}
printans();
}

例3 洛谷P1972 HH的项链

题意

给定一个长度为 N 的数组 a1,a2,,an,以及 m 个询问,每次询问 [L,R] 之间有多少个不同的数。

其中,1n,m,ai106

题解

我们可以发现,如果我们固定了询问的右端点 R,那么无论 L 为多少,在 R 的左侧的所有重复数字中,仅保留最靠右的一个 copy 即可。

例如 arr=1,3,2,1,7,1,那么我们在遍历到 i=4 时,我们仅需要保留最后一个 1 (也就是 index 为 4 的数字)。

由上,在处理 区间内不同的数 时,一个常见的套路是:

  1. 离线处理所有询问,按右端点 R 排序。
  2. 从左到右遍历数组,遍历到 i 时,对于所有 ai 的 copy,仅保留最靠右的那一个(也就是 i),之前的所有 copy 全部删除。
  3. 回答所有 [L,i] 的询问。

那么对于本题,第一步是离线处理询问,按右端点 R 排序。

第二步是遍历数组,遍历过程中维护一个 int pos[] 数组,其中 pos[val] 代表:在 arr[1i] 中,值为 val 的数 最靠右的 index

当我们遍历到 i 时,令 int val = arr[i],将 pos[val] 处的数字 删掉(更新线段树),然后将 i 处的数字 加入线段树,最后更新一下 pos[val] = i;

•本题中,删掉就是将 线段树的位置 i 的值 减去1,加上就是将 线段树的位置 i 的值 加上1

然后回答所有 以 i 为右端点的询问(求 [L,i] 的和即可)。

注:如果询问 在线 怎么办?可以使用主席树!(对于 pr[i] 建主席树)

具体做法参考:CF1422F Boring Queries

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e6+5;
int n, m, arr[maxn];
struct node {
int sum;
} tr[maxn<<2];
void push_up(int cur) {
tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}
void update(int cur, int l, int r, int p, int x) {
if (l == r) {
tr[cur].sum += x; return;
}
int mid = (l+r) >> 1;
if (p <= mid) update(cur<<1, l, mid, p, x);
else update(cur<<1|1, mid+1, r, p, x);
push_up(cur);
}
int query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return tr[cur].sum;
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += query(cur<<1, l, mid, L, R);
if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
return res;
}
int ans[maxn];
int pos[maxn];
struct Query {
int id,l,r;
} q[maxn];
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> arr[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l,r;
cin >> l >> r;
q[i] = {i,l,r};
}
sort(q+1,q+m+1, [](auto a, auto b) {
return a.r < b.r; // 根据右端点离线
});
int ptr = 1;
for (int i = 1; i <= n; i++) {
int val = arr[i];
if (pos[val]) update(1, 1, n, pos[val], -1); // 删去 pos[val]
update(1, 1, n, i, 1); // 加上 i
pos[val] = i; // 更新 pos[val]
while (ptr <= m && q[ptr].r == i) { // 回答所有以 i 作为右端点的询问
int id = q[ptr].id, L = q[ptr].l, R = q[ptr].r;
ans[id] = query(1, 1, n, L, R);
ptr++;
}
if (ptr > m) break;
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例4 CF1000F One Occurrence

题意

给定一个长度为 n 的数组,m 个询问,每次询问一个区间 [l,r],如果这个区间里存在只出现一次的数,输出这个数(如果有多个就输出任意一个),没有就输出 0

其中 n,m5×105,1ai5×105

题解

和例3类似的套路,也是离线处理询问(根据右端点sort),从左往右遍历,仅保留最靠右的复制。

问题在于怎么删除 和 加入数字?因为本题不再是求数量了,所以不能简单的加 1 或者 减 1

会发现,我们仅关心一个区间内是否存在 unique 的数字,对于一个询问 [L,R] 内,我们只要看,是否存在 i 使得 arr[i] 的前一个复制 不在 [L,R]。(也就是说,pos[val] 是否小于 L

那么,我们用线段树维护一下 区间最小值 即可,其中 [L,R] 的区间最小值就代表着所有i[L,R] 中,pos[arr[i]] 的最小值。如果一个区间 [L,R] 的最小值 L,那么答案不存在,否则答案存在。

那么,删除位置 i 就是将它这个位置上的值设为 inf

加入位置 i 的数,就是将它这个位置上的值设为 pos[arr[i]]

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
struct node {
int pre = 1e9, idx = 0; // pre 代表 pos[val] 的值,idx代表这个最小值对应的 index
} tr[maxn<<2];
int pos[maxn], n, m, arr[maxn];
struct Query {
int id,l,r;
} q[maxn];
void push_up(int cur) {
int lpre = tr[cur<<1].pre, rpre = tr[cur<<1|1].pre;
if (lpre < rpre) tr[cur].pre = lpre, tr[cur].idx = tr[cur<<1].idx;
else tr[cur].pre = rpre, tr[cur].idx = tr[cur<<1|1].idx;
}
void update(int cur, int l, int r, int p, int x) {
if (l == r) {
tr[cur].pre = x, tr[cur].idx = l;
return;
}
int mid = (l+r) >> 1;
if (p <= mid) update(cur<<1, l, mid, p, x);
else update(cur<<1|1, mid+1, r, p, x);
push_up(cur);
}
pii query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return {tr[cur].pre, tr[cur].idx};
int mid = (l+r) >> 1;
pii r1 = {1e9, 0}, r2 = {1e9, 0};
if (L <= mid) r1 = query(cur<<1, l, mid, L, R);
if (R > mid) r2 = query(cur<<1|1, mid+1, r, L, R);
if (r1.first > r2.first) return r2;
else return r1;
}
int ans[maxn];
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> arr[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l,r; cin >> l >> r;
q[i] = {i,l,r};
}
sort(q+1, q+m+1, [](auto a, auto b) {
return a.r < b.r;
});
int ptr = 1;
for (int i = 1; i <= n; i++) {
int val = arr[i];
update(1, 1, n, i, pos[val]); // 在i处 加入数字 pos[val]
if (pos[val]) update(1, 1, n, pos[val], 1e9); // 在pos[val] 处删除数字(设为 inf)
pos[val] = i;
while (ptr <= m && q[ptr].r == i) {
int L = q[ptr].l, R = q[ptr].r;
pii res = query(1, 1, n, L, R);
int id = q[ptr].id;
if (res.first >= L) ans[id] = 0; // 如果这个区间内,所有 pos[val] 都 >= L
else ans[id] = arr[res.second]; // 否则,答案存在
ptr++;
}
if (ptr > m) break;
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例5 CF803G Periodic RMQ Problem

题意

给定一个长度为 n 的数组,将这个数组复制为 k 份并且拼接在一起。

然后回答 Q 个询问,分两种询问:(所有询问都在拼接后的数组上进行)

1 l r x:将 [l,r] 中的所有元素改为 x

2 l r:询问 [l,r] 中的最小值。

其中,1n105,1k104,1Q105, l,r[1,n]

题解

注意到数组的总长度可以达到 109。但是询问只有 Q=105,我们于是想到了动态开点线段树。

但是注意到,这个数组是有初始值的,按理说应该把树建好,不能动态开点。

所以我们考虑 不建树,而是在开点的时候,把这个点的初始状态处理好(就是在没有任何修改的情况下,这个点的初始状态),然后再对这个点进行操作。

所以,对于一个点对应的区间 [l,r],怎么处理初始状态?

注意到,由于数组是一个循环节,所以可以分以下三种情况:

  1. [l,r] 的长度 n,它的初始最小值就是数组 [1,n] 的最小值。
  2. [l,r] 的长度 <n,且属于同一个循环节,那么用 ST表 预处理一下 [l mod n,r mod n] 的最小值即可。
  3. [l,r] 的长度 <n,且不属于同一个循环节,那么它的初始最小值就是 [l mod n,n][1,r mod n]

开点的函数是代码里的 build(),注意到只要我们访问到了 cur,并且需要访问它的任意一个 child 时,需要把左右两个子树都开点,这保证了 pushup() 的正确性。

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e7+5;
const int maxm = 1e5+5;
int n,k,Q;
int a[maxm];
int belong(int x) {
return (x-1) / n + 1;
}
struct Node {
int lc, rc, lazy, val;
} tr[maxn];
int id = 0;
int st[maxm][18];
int bin[maxm];
void build_st() {
bin[1] = 0, bin[2] = 1;
for (int i = 3; i < maxm; i++) bin[i] = bin[i>>1] + 1;
for (int i = 1; i <= n; i++) {
st[i][0] = a[i];
}
for (int j = 1; j < 18; j++) {
for (int i = 1; i + (1<<j) - 1 <= n; i++) {
st[i][j] = min(st[i][j-1], st[i+(1<<(j-1))][j-1]);
}
}
}
int ask_st(int l, int r) {
if (l > r) swap(l,r);
int j = bin[r-l+1];
return min(st[l][j], st[r-(1<<j)+1][j]);
}
void build(int& cur, int l, int r) {
if (cur) return;
cur = ++id;
if (r-l+1 >= n) {
tr[cur].val = ask_st(1, n);
return;
}
int bl = belong(l), br = belong(r);
l %= n, r %= n;
if (!l) l = n;
if (!r) r = n;
if (bl == br) {
tr[cur].val = ask_st(l,r);
return;
}
// bl != br
tr[cur].val = min(ask_st(l, n), ask_st(1, r));
}
void push_down(int cur) {
if (!tr[cur].lazy) return;
int lazy = tr[cur].lazy;
tr[cur].lazy = 0;
int lc = tr[cur].lc, rc = tr[cur].rc;
tr[lc].lazy = tr[lc].val = tr[rc].lazy = tr[rc].val = lazy;
}
void push_up(int cur) {
tr[cur].val = min(tr[tr[cur].lc].val, tr[tr[cur].rc].val);
}
void update(int cur, int l, int r, int L, int R, int x) {
if (l >= L && r <= R) {
tr[cur].lazy = tr[cur].val = x;
return;
}
int mid = (l+r) >> 1;
build(tr[cur].lc, l, mid);
build(tr[cur].rc, mid+1, r);
push_down(cur);
if (L <= mid) update(tr[cur].lc, l, mid, L, R, x);
if (R > mid) update(tr[cur].rc, mid+1, r, L, R, x);
push_up(cur);
}
int query(int cur, int l, int r, int L, int R) {
if (l >= L && r <= R) return tr[cur].val;
int mid = (l+r) >> 1;
build(tr[cur].lc, l, mid);
build(tr[cur].rc, mid+1, r);
push_down(cur);
int lres = 1e9, rres = 1e9;
if (L <= mid) lres = query(tr[cur].lc, l, mid, L, R);
if (R > mid) rres = query(tr[cur].rc, mid+1, r, L, R);
push_up(cur);
return min(lres, rres);
}
int rt = 0;
int main() {
fastio;
cin >> n >> k;
for (int i = 1; i <= n; i++) cin >> a[i];
build_st();
cin >> Q;
build(rt, 1, n*k);
for (int i = 1; i <= Q; i++) {
int op,l,r,x;
cin >> op;
if (op == 1) {
cin >> l >> r >> x;
update(1, 1, n*k, l, r, x);
} else {
cin >> l >> r;
int res = query(1, 1, n*k, l, r);
cout << res << "\n";
}
}
}

例6 CF gym103687F. Easy Fix

题意

给定一个 permutation p1,p2pn

定义 Aii 左边比 pi 小的元素数量,Bii 右边比 pi 大的元素数量,Ci=min(Ai,Bi)

给定 m 个询问,每次询问 L,R,回答在交换 pL,pR 的位置以后,i=1nCi 的值?

其中,n105,m2×105

题解

先讨论一下交换后有哪些元素受到影响?

很明显,L,R 位置会受到影响,i(L,R)pi(pL,pR) 的会受到影响,其他的均无影响。


我们先讨论 i(L,R)pi(pL,pR) 这种情况:

Case1: 如果 au<av,那么 Ai1,Bi+1

 Case1.1: 如果 AiBi,那么 Ci1

 Case1.2: 如果 Ai>Bi+1,那么 Ci+1

Case2: 如果 au>av,那么 Ai+1,Bi1

 Case2.1: 如果 AiBi,那么 Ci1

 Case2.2: 如果 Ai<Bi1,那么 Ci+1


然后再讨论一下 L,R 受到的影响:

只要先把 CL,CR 的值从询问中减掉,然后重新找一下换了之后的 CL,CR 即可。

• 需要注意当 pL>pR 时,换了之后 R 位置上的 pL 计算 Ai 时会将 pR 忽略掉,所以要加 1


好的,现在该讨论一下怎么获得 Ai,Bi,Ci

首先,获得 Ai 可以直接用树状数组维护 pi(作为index),树状数组上的值为 0 和 1。遍历一下就可以求出初始状态的 Ai 了。

然后 Bi=pi1Ai,很好理解,这样 Ci 也有了。


再讨论一下怎么处理询问?

我们将 Case1, Case2 分开处理:如果我们开一个线段树/树状数组,index 为 pi,值为 1,0,1 来记录 Case1.1, Case1.2 的情况,那么:

每次询问就是询问一个区间 (L,R) 内,满足 (pL,pR) 的值的一些东西的和。

似乎是 主席树

确实,不过本题询问离线,Stupidcdd 教导我:

能用主席树做的,离线情况下就能用线段树/树状数组。

回忆一下主席树的作用:对于每个区间 [L,R],维护一个权值线段树,原理其实是每个点维护一个线段树,R 上的线段树减去 L1 位置的线段树而已。

那么询问离线的时候,我们只有一个树状数组,可以在遍历的过程中,获得前缀 [1i] 对应的树状数组,但是我们没有持久化,所以保存不下来。

因为树状数组的历史版本保存不下来,不妨从询问下手?

离线的一个常见套路:

对于询问的端点 L,R,每个端点开一个 vector 储存所有的询问,遍历到这个点时处理所有对应的询问,加到答案上。

这个题就是这样,我们想要知道 (L,R) 内满足 (pL,pR) 的值的一些东西的和,那么不妨在遍历到 L 的时候询问一下 (pL,pR),在遍历到 R1 的时候询问一下 (pL,pR),然后一减,不就是答案了么?

copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
struct Query_Info {
int pl, pr; // 询问 [pl, pr] 的部分
int f; // 符号
int idx; // 询问的 idx
int type; // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn]; // qinfo[i] 端点为 i 的需要询问
struct Single_Query {
int x; // 询问 < x 有多少个
int idx; // 询问的 idx
int diff; // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn]; // 单点询问 < x 有多少个
ll ans[maxn];
...
...
...
// (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
for (int j = 1; j <= m; j++) {
auto q = queries[j];
int u = q.first, v = q.second;
if (u != v) {
int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
int type = (p[u] < p[v] ? 1 : 2);
qinfo[u+1].push_back({mn, mx, -1, j, type}); // 由于处理时,还没加上,所以全部应该 + 1 (u -> u+1, v-1 -> v)
qinfo[v].push_back({mn, mx, 1, j, type});
ans[j] -= (C[u] + C[v]); // 先去掉 C[u] + C[v] 的影响,后面直接加上
single[u].push_back({p[v], j, 0});
single[v].push_back({p[u], j, p[u] > p[v]});
}
}
for (int i = 1; i <= n; i++) {
for (auto q : qinfo[i]) {
if (q.type == 1) {
ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
} else {
ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
}
}
for (auto q : single[i]) {
int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
ans[q.idx] += min(AA, BB);
}
if (A[i] <= B[i]) tr1.update(p[i], -1);
if (A[i] > B[i] + 1) tr1.update(p[i], 1);
if (A[i] >= B[i]) tr2.update(p[i], -1);
if (A[i] < B[i] - 1) tr2.update(p[i], 1);
tr3.update(p[i], 1);
}

在这个题中:

  1. (L,R) 我们通过离线询问,答案相减来实现。
  2. (pL,pR) 我们通过权值线段树/权值树状数组来实现。
  3. 树状数组上面维护了 Case1.1, 1.2 (或者 2.1, 2.2) 的情况。

总结一下思路:

离线的关键在于从左向右遍历的时候,获得每个前缀对应的树状数组。

在离线 (L,R) 时,将询问的信息扔给每个端点的 vector,在遍历到的时候再处理,处理时直接对 ans[] 进行操作。

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+55;
const int maxm = 2e5+5;
int n, p[maxn], m;
pii queries[maxm];
struct BIT {
int n, tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= n) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
int query(int p) {
int ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr, tr1, tr2, tr3;
struct Query_Info {
int pl, pr; // 询问 [pl, pr] 的部分
int f; // 符号
int idx; // 询问的 idx
int type; // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn]; // qinfo[i] 端点为 i 的需要询问
struct Single_Query {
int x; // 询问 < x 有多少个
int idx; // 询问的 idx
int diff; // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn]; // 单点询问 < x 有多少个
ll ans[maxn];
ll csum = 0;
int main() {
fastio;
cin >> n;
tr.n = tr1.n = tr2.n = tr3.n = n;
for (int i = 1; i <= n; i++) cin >> p[i];
cin >> m;
for (int i = 1; i <= m; i++) {
int l, r; cin >> l >> r;
if (l >= r) swap(l, r);
queries[i] = {l, r};
}
for (int i = 1; i <= n; i++) {
A[i] = tr.query(p[i] - 1); // 左边比它小
B[i] = p[i] - 1 - A[i]; // 右边比它小
C[i] = min(A[i], B[i]);
csum += C[i];
tr.update(p[i], 1);
}
for (int i = 1; i <= m; i++) ans[i] = csum;
// (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
for (int j = 1; j <= m; j++) {
auto q = queries[j];
int u = q.first, v = q.second;
if (u != v) {
int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
int type = (p[u] < p[v] ? 1 : 2);
qinfo[u+1].push_back({mn, mx, -1, j, type}); // 由于处理时,还没加上,所以全部应该 + 1 (u -> u+1, v-1 -> v)
qinfo[v].push_back({mn, mx, 1, j, type});
ans[j] -= (C[u] + C[v]); // 先去掉 C[u] + C[v] 的影响,后面直接加上
single[u].push_back({p[v], j, 0});
single[v].push_back({p[u], j, p[u] > p[v]});
}
}
for (int i = 1; i <= n; i++) {
for (auto q : qinfo[i]) {
if (q.type == 1) {
ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
} else {
ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
}
}
for (auto q : single[i]) {
int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
ans[q.idx] += min(AA, BB);
}
if (A[i] <= B[i]) tr1.update(p[i], -1);
if (A[i] > B[i] + 1) tr1.update(p[i], 1);
if (A[i] >= B[i]) tr2.update(p[i], -1);
if (A[i] < B[i] - 1) tr2.update(p[i], 1);
tr3.update(p[i], 1);
}
for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例7 Atcoder ABC237G. Range Sort Query

题意

给定一个长度为 n 的permutation p1,p2,pn,和一个正整数 x

给定 q 个询问,询问有两种:

1 L R:将 p[LR] 从小到大排序。

2 L R:将 p[LR] 从大到小排序。

所有询问结束后,回答 x 所在的index。

其中,n,q2×105,x[1,n]

题解

经典套路题。一般看见这种每次询问然后排序的,一般来讲只有在整个数组中 distinct 的数非常少的时候才能做。

这个题注意到我们只关心 x,于是我们可以把剩下的数分为比 x 小和比 x 大的。

而除了 x 以外所有数之间的相对顺序我们并不关心。

所以我们可以直接开两棵线段树,线段树里只储存 0 或者 1,代表这个位置是否存在比 x 大/小 的数。

而 sort 的时候,我们就相当于把两棵线段树的 [L,R] 中,看一下 <x>x 分别有多少个,然后清空整个线段,然后重新进行区间赋值即可。

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5+5;
const int maxm = 1e5+5;
struct Tree_Node {
int sum = 0;
bool flag = 0;
int val; // 区间赋的值
};
struct Segment_Tree {
Tree_Node tr[maxn<<2];
void push_up(int cur) {
tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}
void push_down(int cur, int l, int r) {
if (!tr[cur].flag) return;
int val = tr[cur].val;
tr[cur].flag = 0;
int mid = (l+r) >> 1;
int lc = cur<<1, rc = cur<<1|1;
int llen = (mid - l + 1), rlen = (r - mid);
tr[lc].flag = tr[rc].flag = 1;
tr[lc].val = tr[rc].val = val;
tr[lc].sum = llen * val, tr[rc].sum = rlen * val;
}
// 区间赋值为 x
void update(int cur, int l, int r, int L, int R, int x) {
if (L > R) return;
if (L <= l && R >= r) {
tr[cur].flag = 1;
tr[cur].val = x;
tr[cur].sum = (r-l+1) * x;
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update(cur<<1, l, mid, L, R, x);
if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
// 区间求和
int query(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].sum;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += query(cur<<1, l, mid, L, R);
if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
return res;
}
} tr1, tr2; // 比它小,比它大
int n, q, x, p;
int a[maxn];
int main() {
cin >> n >> q >> x;
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (a[i] == x) p = i;
}
for (int i = 1; i <= n; i++) {
if (a[i] < x) tr1.update(1, 1, n, i, i, 1);
if (a[i] > x) tr2.update(1, 1, n, i, i, 1);
}
while (q--) {
int c, L, R; cin >> c >> L >> R;
int sum1 = tr1.query(1, 1, n, L, R);
int sum2 = tr2.query(1, 1, n, L, R);
tr1.update(1, 1, n, L, R, 0);
tr2.update(1, 1, n, L, R, 0);
if (c == 1) {
tr1.update(1, 1, n, L, L+sum1-1, 1);
tr2.update(1, 1, n, R-sum2+1, R, 1);
if (p >= L && p <= R) p = L + sum1;
} else {
tr2.update(1, 1, n, L, L+sum2-1, 1);
tr1.update(1, 1, n, R-sum1+1, R, 1);
if (p >= L && p <= R) p = L + sum2;
}
}
cout << p << endl;
}

例8 Eerie Subarrays

题意

给定一个 1n 的permutation,求有多少个 subarray 使得其最左边的元素是这个subarray的中位数?

其中,n2×105

题解

假设我们正在考虑的数是 ai,那么只用考虑起点从 i 开始的subarray。

将所有 <ai 的数设为 1,所有 >ai 的数设为 1,那么问题转化为:

R>i 的数量,使得 sum[iR]=0 (等价于使得 sum[i1]=sum[R])。

线段树没法统计这个,但可以用分块来统计!


同时注意到给定的是一个 1n 的permutation,所以我们从 1 开始,初始状态下所有值均为 1,而 1 所在的位置为 0,之后考虑下一个数的时候就是将 1 所在的位置设为 12 所在的位置设为 0,以此类推。

这样每次更新只需要更改两个位置的数,而更改一个位置的数会影响后面的所有的 sum 值。对于块内的,暴力更新,对于后面的所有块,给块加上一个标记即可。所以每次更新复杂度是 O(n)


最后如何统计答案?对于块内,暴力统计,对于一个整块,我们维护块内每个元素出现的次数即可。

如何维护?显然不能对每个块都开 1n 的数组。但注意到每个块里面最多 B=n 个元素,并且每个元素被暴力减的次数 2B。所以我们可以先将每个块的所有元素初始化到 [2B,3B] 的范围内(通过更新标记来初始化),然后每个块里面只要维护一个长度为 3B 的vector即可,总空间复杂度就是 O(n) 了。

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
int n, p[maxn], pos[maxn];
vector<int> b[maxn];
int f[maxn], B;
// [1,B] -> vector<int>, [B+1, 2B] -> vector<int>
int L[maxn], R[maxn];
int bel[maxn];
int bmax = 0;
vector<int> cnt[maxn]; // cnt[bi][i]: 代表 i 在 block(bi) 中的出现次数
// 将位置 x 的值减 1
void upd(int x) {
int bi = bel[x];
int l = L[bi];
int ori = b[bi][x-l];
b[bi][x - l] -= 1;
cnt[bi][ori]--;
cnt[bi][ori-1]++;
}
// 将位置 x 的值减 1, 说明所有 >= x 的值都要减 1
void updall(int x) {
int bi = bel[x];
int r = R[bi];
for (int j = x; j <= r; j++) {
upd(j);
}
for (int b = bi+1; b <= bmax; b++) {
f[b] -= 1;
}
}
// 回答 sum[x] 的值
int check(int x) {
int bi = bel[x], l = L[bi];
return f[bi] + b[bi][x-l];
}
// 询问有多少个 y >= x 使得 sum[y] = sum[x]
ll query(int x) {
int sumx = check(x);
int bi = bel[x], r = R[bi];
ll res = 0;
for (int j = x; j <= r; j++) {
if (check(j) == sumx) res++;
}
for (int b = bi+1; b <= bmax; b++) {
int flag = f[b];
// 有多少个 cnt[y] 使得 y + flag == sumx -> y == sumx - flag
if (sumx - flag >= 0 && sumx - flag < cnt[b].size()) res += cnt[b][sumx-flag];
}
return res;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> p[i], pos[p[i]] = i;
fill(L, L+maxn, 1e9);
B = sqrt(n);
for (int i = 1; i <= n; i++) {
bel[i] = (i-1) / B + 1;
int bi = bel[i];
L[bi] = min(L[bi], i);
R[bi] = max(R[bi], i);
bmax = max(bmax, bi);
}
for (int i = 1; i <= bmax; i++) {
int j = 3*B + 10;
while (j--) cnt[i].push_back(0);
}
for (int i = 1; i <= n; i++) {
int bi = bel[i];
// 1: [1, B] -> flag = -2B, 2: [B+1, 2B] -> flag = 0, 3: flag = B
int flag = (bi - 3) * B;
f[bi] = flag;
b[bi].push_back(i - flag);
cnt[bi][i-flag]++;
}
ll ans = 0;
for (int i = 1; i <= n; i++) {
updall(pos[i]);
if (i > 1) updall(pos[i-1]);
ans += query(pos[i]);
}
cout << ans << endl;
}