主席树
Contents
介绍
主席树全名叫做 可持久化权值线段树,一般用于一个数组上,有以下的功能:
- 对于每一个区间 都能 开一个权值线段树。
- 能够维护数组的 历史版本。(仅用于单点插入/修改)
时间复杂度:$O(n\log n)$,空间复杂度:$O(n \log n)$。
思想
节点的复制
主席树的主要思想在于,对于在线段树上的单点修改,如果要维护多个版本(修改前和修改后),我们可以复制出新的节点,来维护新版本的信息。
由于单点修改仅会影响一条链(从叶子节点,一直到根节点),所以每个版本最多会复制出 $O(\log n)$ 个新节点。
如上图,橙色部分就是一个新版本,复制出来了一条链。
• 因为复制节点,所以也需要 动态开点。
来一道例题:
例1 洛谷 P3919 【模板】可持久化线段树 1(可持久化数组)
题意
维护一个长度为 $N$ 的数组,共有 $M$ 次询问。询问格式如下:
$v ~ 1 ~ p ~ x$:在版本 $v$ 的基础上,将 $a_p$ 修改为 $x$
$v ~ 2 ~ p$:在版本 $v$ 的基础上,询问 $a_p$ 的值
每次询问后,都生成一个新版本。(所以共有 $M+1$ 个版本)
题解
可持久化的操作在上面说过了。对于每一次修改,都 复制一条链。如果是操作 $2$,复制根节点就可以了。
可持久化的复制节点方式和普通的动态开点略有不同,主要体现在:
- 不需要看
cur == 0
与否,直接复制即可,并且将复制后的编号返回。 - 需要记录 上一个版本 的 同位置 节点的标号
pre
- 需要
build
操作root[0] = build(1,n)
,因为初始状态无论是否为空,都需要把线段树开好,否则后面无法复制(也无法进行相减操作)。
访问不同版本的线段树时,就访问它们的 root
即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+5;
const int maxm = 3e7;
int n,m,root[maxn],id, arr[maxn];
struct node {
int lc,rc,val;
} tr[maxm];
int build(int l, int r) { // 参数中没有 cur
int cur = ++id; // 直接添加
if (l == r) {
tr[cur].val = arr[l];
return cur;
}
int mid = (l+r) >> 1;
tr[cur].lc = build(l, mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
// 将位置p 的值修改为 x
// pre 是前一个版本的 同位置节点
int insert(int pre, int l, int r, int p, int x) {
int cur = ++id;
tr[cur] = tr[pre];
if (l == r) {
tr[cur].val = x;
return cur;
}
int mid = (l+r) >> 1;
if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p, x);
if (p > mid) tr[cur].rc = insert(tr[pre].rc, mid+1, r, p, x);
return cur;
}
int query(int cur, int l, int r, int p) {
if (l == r) return tr[cur].val;
int mid = (l+r) >> 1;
if (p <= mid) return query(tr[cur].lc, l, mid, p);
else return query(tr[cur].rc, mid+1, r, p);
}
void init() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> arr[i];
root[0] = build(1,n); // 注意需要 build(1,n),这是版本 0
}
int main() {
init();
for (int i = 1; i <= m; i++) {
int v, op, p;
cin >> v >> op >> p;
if (op == 1) {
int x; cin >> x;
root[i] = insert(root[v], 1, n, p, x);
} else {
int res = query(root[v], 1, n, p); // 版本 v 的根
cout << res << "\n";
root[i] = root[v];
}
}
}
区间询问(两个线段树相减)
区间问题常用的一个思想是 前缀和。
那么对于一个数组 $a_1,a_2,…,a_n$,我们可以维护 $n+1$ 个版本的权值线段树。分别维护了 $sum_0, sum_1, sum_2, …, sum_n$ 的信息。
例如,$sum_0$ 是一个空的权值线段树(已经 build()
过的),$sum_3$ 这个权值线段树维护的就是 $a_1,a_2,a_3$ 这个数组的信息。
要求 $sum_3$,我们在 $sum_2$ 的基础上,将 $a_3$ 的信息加进 $sum_2$(单点修改),形成一个新版本的权值线段树即可。
那么,如果我们要求 $[L,R]$ 这个区间对应的权值线段树,只要求出 $sum_r - sum_{l-1}$ 对应的权值线段树就可以了!
线段树之间怎么相减?把对应节点维护的值相减一下即可!
以上就是主席树的全部内容了,本质上是 节点复制 + 线段树相减。
这样,对于每一个区间,都可以获得一个权值线段树。
例题
例2 洛谷P3834 【模板】可持久化线段树 2(主席树)
题意
给定 $N$ 个整数 $a_1,a_2,…,a_n$,和 $m$ 个询问,每次询问 $[L,R]$ 之间的第 $k$ 小值。保证询问合法。
其中,$1 \leq n,m \leq 2 \times 10^5, |a_i| \leq 10^9$
题解
思考一个问题能否用主席树,我们可以先思考,对于整个数组,我们能否用权值线段树解决?
答案是可以的!如果我们要求整个数组的第 $k$ 小,可以将所有数字先离散化成排名,然后用权值线段树来维护各个排名的数量。求可以求出第 $k$ 小了!
所以主席树可以解决,步骤如下:
- 对整个数组进行离散化
- 维护主席树,对于每一个 $[L,R]$ 都可以获得一个权值线段树,然后可以求得第 $k$ 小。
注:离散化的一个很方便的写法是 struct
+ sort()
,然后遍历 sort 后的数组:
int arr[maxn];
int N = 0, rk[maxn], val[maxn];
// N: 排名数,rk[i]: arr[i]的排名,val[i]: 排名为i的数字的值
void init() {
for (int i = 1; i <= n; i++) cin >> arr[i].val, arr[i].id = i;
sort(arr+1, arr+n+1, [](auto a, auto b) {
return a.val < b.val;
});
rk[arr[1].id] = ++N;
val[1] = arr[1].val;
for (int i = 2; i <= n; i++) {
if (arr[i].val > arr[i-1].val) N++;
rk[arr[i].id] = N;
val[N] = arr[i].val;
}
}
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 1e7;
int n,m;
struct Num {
int val, id;
} arr[maxn];
int rk[maxn], val[maxn]; // rk[1] 代表 arr[1] 离散化后的值(排名), val[1] 代表整个array中第1小的值
int N = 0;
struct node {
int lc, rc, cnt;
} tr[maxm];
int root[maxn], id = 0;
int build(int l, int r) {
int cur = ++id;
if (l == r) {
return cur;
}
int mid = (l+r) >> 1;
tr[cur].lc = build(l,mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
// pre 是上个版本的节点
// 令 p 位置的 cnt += 1
int insert(int pre, int l, int r, int p) {
int cur = ++id;
tr[cur] = tr[pre]; // 复制一份上个版本
tr[cur].cnt++; // 添加了一个节点
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid) tr[cur].lc = insert(tr[cur].lc, l, mid, p);
if (p > mid) tr[cur].rc = insert(tr[cur].rc, mid+1, r, p);
return cur;
}
// 查询 tr[pre,cur] 之间的第k小,返回具体的值
int query(int pre, int cur, int l, int r, int k) {
if (l == r) return val[l];
int prelc = tr[pre].lc, lc = tr[cur].lc;
int prerc = tr[pre].rc, rc = tr[cur].rc;
int mid = (l+r) >> 1;
int lcnt = tr[lc].cnt - tr[prelc].cnt;
if (lcnt >= k) return query(prelc, lc, l, mid, k); // 如果左边有 >= k 个数
else return query(prerc, rc, mid+1, r, k-lcnt); // 否则只看右边
}
void init() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> arr[i].val, arr[i].id = i;
sort(arr+1, arr+n+1, [](auto a, auto b) {
return a.val < b.val;
});
rk[arr[1].id] = ++N;
val[1] = arr[1].val;
for (int i = 2; i <= n; i++) {
if (arr[i].val > arr[i-1].val) N++;
rk[arr[i].id] = N;
val[N] = arr[i].val;
}
root[0] = build(1, N);
for (int i = 1; i <= n; i++) {
root[i] = insert(root[i-1], 1, N, rk[i]);
}
}
int main() {
init();
for (int i = 1; i <= m; i++) {
int l,r,k;
cin >> l >> r >> k;
int ans = query(root[l-1], root[r], 1, N, k);
cout << ans << "\n";
}
}
例3 CF1422F Boring Queries
题意
给定 $n$ 个正整数 $a_1, a_2, … , a_n$,$q$ 次询问。
每次询问 $[x,y]$,为了保证询问在线,记上一次询问的答案为 $last$,然后令 $L = ((x+last) \text{ mod } n) + 1$,$R = ((y+last) \text{ mod } n) + 1$,如果 $L > R$,则交换 $L,R$。
每次询问,回答 $[L,R]$ 之间所有数的 LCM,答案对 $10^9+7$ 取模。
其中,$1 \leq n,q \leq 10^5, 1 \leq a_i \leq 2 \times 10^5, 1 \leq x,y \leq 10^5$
题解
本题主要有两个难点:
- 所有询问在线。
- $LCM$ 数字极大,且需要取模,无法正常维护。
因为 $LCM$ 的值极大,且取模,可以考虑 质因数分解。
但是,$a_i \leq 2 \times 10^5$,我们无法直接维护每一个质因子。
这时候可以考虑 根号分治,将质因子分为两部分:
一部分是 $\leq \sqrt {(2 \times 10^5)}$ 的质因子。
还有一部分是 $> \sqrt {(2 \times 10^5)}$ 的质因子。
对于 $\leq \sqrt {(2 \times 10^5)}$ 的质因子,我们会发现这种小因子只有 $87$ 个。所以我们只要维护 $87$ 个 ST表 来维护每个小因子在区间内的最大次数即可。
对于 $> \sqrt {(2 \times 10^5)}$ 的质因子,我们会发现它们的出现次数最多为 $1$,并且对于任意一个 $a_i$,它最多只能包含一个这样的大因子。
所以对于大因子而言,求 $LCM$ 就转化为:
求一个区间内,有哪些不同的大因子出现过,将这些 unique 的大因子乘起来就可以了!
那么,如何解决如下的问题?
给定一个数组,询问一个区间,求该区间内 所有不同的数的乘积。
这个问题,我们在 HH的项链 中见到过。
但是这个题 强制在线,没法用上面的离线方法来解决。
所以,我们维护一个 pr[]
数组,其中 pr[i]
代表:对于 i
,上一个值等于 arr[i]
的 index 的值。
(即:pr[i] = j
,其中 arr[j] = arr[i], j < i
)
在询问 $[L,R]$ 时,我们将问题转化为:
求区间内所有
arr[i]
的乘积,使得:
$i \in [L,R]$
pr[i] < L
这个问题可以用 主席树 解决。我们根据 pr[i]
的值来建主席树,树上节点的值就维护乘积。线段树相减 就用 乘积的逆元 来实现。
最后总结一下本题的步骤:
- 欧拉筛求出所有质数。
- 维护 $87$ 个小于等于 $450$ 的小质因子,建立 $87$ 个ST表。
- 建好大因子的主席树。
本题空间卡的非常紧,我们将 ST表 的数组改为了
short
类型,可以避免 MLE。
代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e5+2;
const int maxm = 2e5+2;
vector<int> primes;
bool is_prime[maxm];
int n, arr[maxn];
void euler() {
fill(is_prime, is_prime+maxm, 1);
for (int cur = 2; cur < maxm; cur++) {
if (is_prime[cur]) primes.push_back(cur);
for (auto p : primes) {
if (p * cur >= maxm) break;
is_prime[p*cur] = 0;
if (cur % p == 0) break;
}
}
}
short st[87][maxn][18];
short bin[maxn];
int ask_st(int i, int l, int r) {
int len = (r-l+1);
int k = bin[len];
return max(st[i][l][k], st[i][r-(1<<k)+1][k]);
}
void build_st() {
bin[1] = 0;
bin[2] = 1;
for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
for (int i = 0; i <= 86; i++) {
int p = primes[i];
for (int j = 1; j <= n; j++) {
int cnt = 0;
while (arr[j] % p == 0) {
arr[j] /= p, cnt++;
}
st[i][j][0] = cnt;
}
for (int k = 1; k < 18; k++) {
for (int j = 1; j+(1<<k)-1 <= n; j++) {
st[i][j][k] = max(st[i][j][k-1], st[i][j+(1<<(k-1))][k-1]);
}
}
}
}
struct node {
int lc,rc;
ll val = 1;
} tr[maxn<<5];
int id, root[maxn], pr[maxn];
inline ll qpow(ll a, ll b) {
if (!b) return 1;
ll res = 1;
while (b) {
if (b&1)
(res *= a) %= mod;
(a *= a) %= mod;
b >>= 1;
}
return res;
}
inline ll inv(ll a) {
return qpow(a, mod-2);
}
int build(int l, int r) {
int cur = ++id;
if (l == r) return cur;
int mid = (l+r) >> 1;
tr[cur].lc = build(l,mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
void push_up(int cur) {
int lc = tr[cur].lc, rc = tr[cur].rc;
tr[cur].val = (tr[lc].val * tr[rc].val) % mod;
}
int insert(int pre, int l, int r, int p, ll x) {
int cur = ++id;
tr[cur] = tr[pre];
(tr[cur].val *= x) %= mod;
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid)
tr[cur].lc = insert(tr[pre].lc, l, mid, p, x);
else
tr[cur].rc = insert(tr[pre].rc, mid+1, r, p, x);
push_up(cur);
return cur;
}
// 询问 [pre, cur] 之间的大质数之积, 且保证 pr[i] < p
ll query(int pre, int cur, int l, int r, int p) {
if (r < p) return (tr[cur].val * inv(tr[pre].val)) % mod;
if (l >= p) return 1;
int mid = (l+r) >> 1;
ll res = 1;
(res *= query(tr[pre].lc, tr[cur].lc, l, mid, p)) %= mod;
(res *= query(tr[pre].rc, tr[cur].rc, mid+1, r, p)) %= mod;
return res;
}
int pos[maxm];
void init() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> arr[i];
euler();
build_st();
root[0] = build(0,n);
for (int i = 1; i <= n; i++) {
int val = arr[i];
pr[i] = pos[val];
pos[val] = i;
}
for (int i = 1; i <= n; i++) {
root[i] = insert(root[i-1], 0, n, pr[i], arr[i]);
}
}
ll last = 0;
ll Query(ll L, ll R) {
L += last, R += last; L %= n, R %= n; L++, R++;
if (L > R) swap(L,R);
ll big = query(root[L-1], root[R], 0, n, L);
ll small = 1;
for (int i = 0; i <= 86; i++) {
ll p = primes[i];
int c = ask_st(i,L,R);
(small *= qpow(p,c)) %= mod;
}
last = (small * big) % mod;
return last;
}
int main() {
init();
int q; cin >> q;
while (q--) {
ll L,R; cin >> L >> R;
cout << Query(L,R) << "\n";
}
}
例4 洛谷P2468 [SDOI2010]粟粟的书架
题意
给定一个 $R \times C$ 的矩阵,矩阵中所有元素均为正整数,有 $M$ 个询问,每次询问:
$x_1 ~ y_1 ~ x_2 ~ y_2 ~ H$:求 $(x_1,y_1)$ 和 $(x_2,y_2)$ 之间的矩形中,最少取多少个数字可以让数字之和 $\geq H$?
数据范围:
对于 $50$ % 的数据,有 $R,C \leq 200, M \leq 2 \times 10^5$
对于另外 $50$ % 的数据,有 $R = 1, C \leq 5 \times 10^5, M \leq 2 \times 10^4$
矩阵中所有元素满足值在 $[1, 1000]$ 之间,$H \leq 2 \times 10^9$
题解
我们需要让数字之和尽量大,那么每次询问就选择最大的那些数。
可以维护主席树,以元素的值作为权值,节点之中维护 $cnt$ 和 $sum$。
对于 $R = 1$ 的情况很好解决。那么对于 $R \leq 200$ 呢?
有两种方法,
法一:维护 $200$ 棵主席树
询问的过程中,把每一行的线段树都进行 相加 (具体实现通过维护 vector<int> pre, vector<int> cur
)。
法二:维护二维前缀和
sum[i][j][k]
代表 $(1,1)$ 和 $(i,j)$ 之间矩阵之中,数值 $\geq k$ 的数字的和
cnt[i][j][k]
代表 $(1,1)$ 和 $(i,j)$ 之间矩阵之中,数值 $\geq k$ 的数字的数量
询问的时候,二分一下 $k$ 就可以了。
以下给出主席树代码。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
struct node {
int lc, rc, sum, cnt;
} tr[maxn<<5];
int n,m,Q, id = 0;
vector<int> sum[201];
vector<int> arr[201];
vector<int> root[201];
int build(int l, int r) {
int cur = ++id;
if (l == r) return cur;
int mid = (l+r) >> 1;
tr[cur].lc = build(l, mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
int insert(int pre, int l, int r, int p) {
int cur = ++id;
tr[cur] = tr[pre];
tr[cur].sum += p;
tr[cur].cnt++;
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p);
else tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
return cur;
}
pii query(vector<int>& pre, vector<int>& cur, int l, int r, int need) {
if (need <= 0) return {0,0};
int allsum = 0, allcnt = 0;
for (int i = 0; i < pre.size(); i++) {
int p = pre[i], c = cur[i];
allsum += tr[c].sum - tr[p].sum; // 200棵线段树相加
allcnt += tr[c].cnt - tr[p].cnt;
}
if (allsum <= need) return {allcnt, allsum};
if (l == r) return {(need/l + (need % l > 0)), (need/l + (need % l > 0)) * l};
int mid = (l+r) >> 1;
vector<int> plc, lc, prc, rc;
for (int i = 0; i < pre.size(); i++) {
int p = pre[i], c = cur[i];
plc.push_back(tr[p].lc);
lc.push_back(tr[c].lc);
prc.push_back(tr[p].rc);
rc.push_back(tr[c].rc);
}
int needcnt = 0, needsum = 0;
pii res = query(prc, rc, mid+1, r, need);
need -= res.second;
needcnt += res.first;
needsum += res.second;
res = query(plc, lc, l, mid, need);
needcnt += res.first;
needsum += res.second;
return {needcnt, needsum};
}
bool ok(int x1, int y1, int x2, int y2, int tar) {
int res = 0;
for (int i = x1; i <= x2; i++) {
res += sum[i][y2] - sum[i][y1-1];
}
return res >= tar;
}
void init() {
cin >> n >> m >> Q;
for (int i = 1; i <= n; i++) {
sum[i] = arr[i] = root[i] = vector<int>(m+1,0);
root[i][0] = build(1, 1000);
for (int j = 1; j <= m; j++) {
int val; cin >> val;
arr[i][j] = val;
sum[i][j] = sum[i][j-1] + val;
root[i][j] = insert(root[i][j-1], 1, 1000, val);
}
}
}
int main() {
init();
while (Q--) {
int x1,y1,x2,y2,tar; cin >> x1 >> y1 >> x2 >> y2 >> tar;
if (!ok(x1,y1,x2,y2,tar)) cout << "Poor QLW" << "\n";
else {
vector<int> cur,pre;
for (int i = x1; i <= x2; i++) cur.push_back(root[i][y2]), pre.push_back(root[i][y1-1]);
pii ans = query(pre, cur, 1, 1000, tar);
cout << ans.first << "\n";
}
}
}
例5 洛谷P2633 Count on a tree
题意
给定一棵 $n$ 个节点的树,每个点 $i$ 具有权值 $a_i$。
有 $m$ 个询问,每次询问 $u ~ v ~ k$,回答 $u \text { xor } last$ 和 $v$ 的最短路径中,第 $k$ 小的点权。
其中,$last$ 为上次询问的答案,且保证每次询问均合法。
点权值的范围在 $[0, 2^{31}-1]$ 之间。
题解
树上主席树。
与普通主席树不同,我们主席树中的 前缀 代表了从 parent 继承而来的部分。即,我们选定 $1$ 作为根,那么一条从上到下的路径,就形成了一个主席树上的 前缀。
insert()
的时候,就有如下的代码:
int root_id = 0, root[maxn], ver_root[maxn]; // ver_root[u] 的值为 u 对应的版本的 root index
void dfs(int u, int p) {
root_id++;
ver_root[u] = root_id;
root[root_id] = insert(root[ver_root[p]], 1, N, rk[u]); // 从parent p那里继承而来
// ....
}
那么,怎么将 $u,v$ 之间的路径转化为 前缀之间的加减?(也就是线段树之间的加减)
回忆一下 树上启发式合并中,一道关于形成回文串路径的题目:CF741D
我们当时采用的是:
$$f_{u,v} = (f_u \text{ xor } f_x) \text{ xor } (f_v \text{ xor } f_x) = f_u \text{ xor } f_v$$
其中 $x = LCA(u,v)$
那么本题的思路也一样,将路径问题转化为 $LCA$ 问题。(实际上,本质就是 树上差分)
所以,令 $x = LCA(u,v)$,令 $f_u$ 为 $1 \rightarrow u$ 的路径对应的线段树。
我们发现在线段树上,$u,v$ 之间的路径 $f_{u,v}$,就是 $$f_u + f_v - f_x - f_{par(x)}$$
那么剩下的就是经典的主席树模版,区间第 $k$ 小问题了。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
struct Num {
int id, v;
} nums[maxn];
struct node {
int lc, rc, cnt;
} tr[maxn<<5];
int n, m, arr[maxn], rk[maxn], val[maxn], N = 0, id = 0;
int root[maxn];
int ecnt = 1, head[maxn], par[maxn][19], dep[maxn];
struct Edge {
int to, nxt;
} edges[maxn<<1];
void addEdge(int u, int v) {
Edge e = {v, head[u]};
edges[ecnt] = e;
head[u] = ecnt++;
}
int build(int l, int r) {
int cur = ++id;
if (l == r) return cur;
int mid = (l+r) >> 1;
tr[cur].lc = build(l, mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
int insert(int pre, int l, int r, int p) {
int cur = ++id;
tr[cur] = tr[pre];
tr[cur].cnt++;
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid)
tr[cur].lc = insert(tr[pre].lc, l, mid, p);
else
tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
return cur;
}
int query(int pre1, int pre2, int cur1, int cur2, int l, int r, int k) {
if (l == r) return l;
int plc1 = tr[pre1].lc, prc1 = tr[pre1].rc;
int plc2 = tr[pre2].lc, prc2 = tr[pre2].rc;
int lc1 = tr[cur1].lc, rc1 = tr[cur1].rc;
int lc2 = tr[cur2].lc, rc2 = tr[cur2].rc;
int mid = (l+r) >> 1;
int lcnt = tr[lc1].cnt + tr[lc2].cnt - tr[plc1].cnt - tr[plc2].cnt; // 线段树加减
if (k <= lcnt) return query(plc1, plc2, lc1, lc2, l, mid, k);
else return query(prc1, prc2, rc1, rc2, mid+1, r, k-lcnt);
}
int jump(int u, int d) {
int c = 0;
while (d) {
if (d&1) u = par[u][c];
d >>= 1, c++;
}
return u;
}
int LCA(int u, int v) {
if (dep[u] < dep[v]) swap(u,v);
int d = dep[u] - dep[v];
u = jump(u, d);
if (u == v) return u;
for (int j = 18; j >= 0; j--) {
if (par[u][j] != par[v][j]) u = par[u][j], v = par[v][j];
}
return par[u][0];
}
int root_id = 0, ver_root[maxn]; // ver_root[u] 的值为 u 对应的版本的 root index
void dfs(int u, int p) {
root_id++;
ver_root[u] = root_id;
root[root_id] = insert(root[ver_root[p]], 1, N, rk[u]);
dep[u] = dep[p] + 1;
par[u][0] = p;
for (int j = 1; j <= 18; j++) par[u][j] = par[par[u][j-1]][j-1];
for (int e = head[u]; e; e = edges[e].nxt) {
int to = edges[e].to;
if (to == p) continue;
dfs(to, u);
}
}
int Query(int u, int v, int k) {
int lca = LCA(u,v);
int res = query(root[ver_root[lca]], root[ver_root[par[lca][0]]], root[ver_root[u]], root[ver_root[v]], 1, N, k);
return val[res];
}
void init() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> arr[i], nums[i] = {i, arr[i]};
for (int i = 1; i <= n-1; i++) {
int u,v; cin >> u >> v;
addEdge(u, v); addEdge(v, u);
}
// 离散化
sort(nums + 1, nums + 1 + n, [](auto a, auto b) {
return a.v < b.v;
});
rk[nums[1].id] = ++N; val[1] = nums[1].v;
for (int i = 2; i <= n; i++) {
if (nums[i].v > nums[i-1].v) {
N++;
}
rk[nums[i].id] = N;
val[N] = nums[i].v;
}
root[0] = build(1, N);
dfs(1, 0);
}
int last = 0;
int main() {
fastio;
init();
while (m--) {
int u,v,k; cin >> u >> v >> k;
u ^= last;
last = Query(u,v,k);
cout << last << "\n";
}
}
例6 CF813E Army Creation
题意
给定 $n$ 个数字 $a_1,a_2,…,a_n$,给定一个数字 $k$,有 $q$ 次询问,每次询问:
$x ~ y$:令 $L = ((x+last) \text { mod } n) + 1, R = ((y+last) \text { mod } n) + 1$,回答 $[L,R]$ 中,最多可以选多少个数,使得任何一个数字选择的数字次数 $\leq k$?
题解
仍然是 强制在线,还是和例4的做法一样,维护 pr[]
数组:只不过维护的值改了一下:
$pr[i] = j$,满足 $arr[i] = arr[j]$,且 $i$ 往前走 $k$ 个位置,就得到 $j$ 。
然后每次询问就回答 $[L,R]$ 之中,有多少个 $i \in [L,R]$ 满足 $pr[i] < L$。
预处理这个
pr[]
数组,可以通过倍增的方式(类似于 LCA)来做。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
int n,k;
int par[maxn][19], pr[maxn], id = 0, root[maxn], last = 0, arr[maxn], pos[maxn];
struct node {
int lc,rc,cnt;
} tr[maxn<<5];
int build(int l, int r) {
int cur = ++id;
if (l == r) return cur;
int mid = (l+r) >> 1;
tr[cur].lc = build(l, mid);
tr[cur].rc = build(mid+1, r);
return cur;
}
int insert(int pre, int l, int r, int p) {
int cur = ++id;
tr[cur] = tr[pre];
tr[cur].cnt++;
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p);
else tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
return cur;
}
int query(int pre, int cur, int l, int r, int L) {
if (r < L) return tr[cur].cnt - tr[pre].cnt;
if (l >= L) return 0;
int mid = (l+r) >> 1;
int res = 0;
res += query(tr[pre].lc, tr[cur].lc, l, mid, L);
res += query(tr[pre].rc, tr[cur].rc, mid+1, r, L);
return res;
}
int jump(int u, int d) {
int c = 0;
while (d) {
if (d&1) u = par[u][c];
c++; d >>= 1;
}
return u;
}
void init() {
cin >> n >> k;
for (int i = 1; i <= n; i++) {
int val; cin >> val;
arr[i] = val;
par[i][0] = pos[val];
pos[val] = i;
}
root[0] = build(0, n);
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= 18; j++) {
par[i][j] = par[par[i][j-1]][j-1];
}
pr[i] = jump(i, k);
root[i] = insert(root[i-1], 0, n, pr[i]);
}
}
int main() {
init();
int Q; cin >> Q;
while (Q--) {
int l,r; cin >> l >> r;
l = (l + last) % n + 1;
r = (r + last) % n + 1;
if (l > r) swap(l,r);
last = query(root[l-1], root[r], 0, n, l);
cout << last << "\n";
}
}
例7 CF1404C Fixed Point Removal
题意
给定长度为 $n$ 的数组 $a_1,a_2,…,a_n$。
我们每次可以选择 $i$ 使得 $a_i = i$,然后将它删掉。删掉后,剩下的部分会合并在一起。
现在有 $q$ 个询问,每次询问 $(x,y)$,回答:
如果将 $a_1,a_2,…,a_x$ 赋值为 $n+1$,$a_n,a_{n-1},…,a_{n-y+1}$ 赋值为 $n+1$ (实际上,就是让 $a$ 的前 $x$ 个元素 和 后 $y$ 个元素变得无法删除。),我们能最多删除多少个元素?
询问之间完全独立,互不影响。
其中,$1 \leq n,q \leq 3 \times 10^5, a_i \in [1,n], x,y \geq 0, x+y < n$
题解
首先,我们先不考虑询问的问题。
对于原数组,如何求出最多删除的元素个数?
我们会发现,当我们删除一个元素时,它右边的所有元素的 index 都会减少 $1$,这可能导致右边的元素又出现了可删除的。
同时我们发现,如果对于 $i$,有 $a_i > i$,则无论怎么删除,$i$ 只减不增,这样 $a_i = i$ 永远不可能成立。
这意味着,当出现 $a_i = i$ 时,就要立刻删除,如果它左边的元素被删了,它就再也删不掉了。
所以对于原数组,最优的删除方法是:
从右侧开始删,删完一个数字后,判断右边有没有出现新的可删数字,如果有的话继续从右开始。
有了贪心的删除策略,我们不能简单的模拟删除过程,因为这是 $O(n^2)$ 的。
每次删除操作,我们需要对右边的元素的index 进行区间减 $1$,很明显可以用线段树来维护。
不妨将问题转化一下,令 $b_i = i - a_i$:
- 当 $b_i = 0$ 时,可以删除,删除时,将 $[i+1,n]$ 的所有 $b_i$ 都减去 $1$。
- 当 $b_i < 0$ 时,永远都不可能删除。
- 当 $b_i > 0$ 时,如果它未来的某时刻被减为 $0$ 了,则可以被删掉。
那我们在线段树里面维护一个 $min$,代表区间最小值。
如果一个元素 $< 0$,则直接把它赋值为 $10^9$,无论怎么减,它都不可能等于 $0$,代表着它是一个无效元素(永远不可能被删除,或者已经被删除)。
每次询问一下区间最小值的位置,如果最小值为 $0$,就把这个位置 $i$ 的元素删掉,然后将 $[i+1,n]$ 都减去 $1$,然后将 $b_i$ 设为 $10^9$(代表已删除)。
现在有了 无询问 情况下的解,有询问怎么办?
我们观察一下样例数组(长度为 $13$)
$[2,2,3 ,9 ,5 ,4 ,6 ,5 ,7, 8, 3, 11, 13]$
我们删除的 index 顺序(指原数组的index)是:$13,5,12,7,10,9,3,8,6,2,11$。
假设我们询问了 $x=3, y=1$,这意味着 $1,2,3$ 和 $13$ 都不能出现在这个删除序列中了(我们给它们 打上标记),并且它后面的 所有 比它大 的数字也被打上标记了(因为它实际上依赖前面的数字,但是前面的这些数字无法被删了)!如下:
找 $1$,$1$ 不在删除序列中,忽略。
找 $2$,$2$ 在删除序列中,后面比它大的数字有 $11$,所以 $2,11$ 打上标记。
找 $3$,$3$ 在删除序列中,后面比它大的数字有 $8,6,2,11$,所以 $3,8,6,2,11$ 打上标记。
找 $13$,$13$ 在删除序列中,后面不存在比它大的数字。所以 $13$ 打上标记。
以上,剩下没有被标记的,只有 $5,12,7,10,9$,共 $5$ 个数字,所以答案为 $5$。
将上述的模拟过程,总结一下就是:
设删除序列(就是上面的 $13,5,12,7,10,9,3,8,6,2,11$)为 $c$。
那么,每次询问 $(x,y)$,所有 未被标记 的index $i$,必须得满足以下的所有条件:
- $c_i \in [x+1, n-y]$
- $pre_i \geq x+1$
其中,$pre_i$ 代表 $\min \{ c_1,c_2,…,c_{i-1}\}$。这是因为,如果 $pre_i \leq x$,则 $i$ 前面必然存在一个数字被标记了,所以 $i$ 也要被标记。
现在问题就转化为:
给定一个长度为 $n$ 的数组,每个元素是 $(c, pre)$ 的形式。
每次询问 $(x,y)$,求有多少个 $i \in [1,n]$ 满足以下所有条件:
- $c_i \in [x+1, n-y]$
- $pre_i \geq x+1$
有离线和在线两种方法。
在线做法:
将数组根据 $c$ 的值,sort一下。
每次询问的时候,先用二分找到左边界 $l$,满足 $c_l \geq x+1$,且 $l$ 尽可能小。
再二分找到右边界 $r$,满足 $c_r \leq n-y$,且 $r$ 尽可能大。
然后问题就转化为,求 $[l,r]$ 之间,有多少个元素 $i$,满足:
- $i \in [l,r]$
- $pre_i \geq x+1$
那这就是一个标准的主席树问题了。
复杂度:$O(n\log^2n)$
离线做法:
将所有的询问,根据 $x$ 的值,从大到小 进行sort。
然后,将数组根据 $pre$ 的值,从大到小 进行sort。
所以,我们在回答每个询问 $(x,y)$ 的时候,我们只需要考虑 $pre_i \geq x+1$ 的部分。
也就是说我们可以开一个线段树,维护 $c_i$ 的值。
假如我们当前处理到了 询问 $(x,y)$,我们只在线段树内维护所有满足 $pre_i \geq x+1$ 的 $c_i$ 即可。
• 这本质上是一个,将数组内的元素,根据 $pre_i$ 的值,逐一插入到线段树中的过程。
复杂度:$O(n\log n)$
总结:
本题的思考流程分为以下几个步骤:
- 找到最优的删除策略。
- 对于无询问状态下,如何模拟删除过程。
- 对于有询问状态下,如何模拟删除过程。
- 如何处理一个 类似二维数点 的问题。
在线做法(主席树)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
int n,q;
struct node {
int idx, val;
} arr[maxn];
struct tree_node {
int minval, idx;
int lazy = 0;
} tr[maxn<<2];
void push_up(int cur) {
tr[cur].minval = min(tr[cur<<1].minval, tr[cur<<1|1].minval);
}
void push_down(int cur) {
if (tr[cur].lazy == 0) return;
int lazy = tr[cur].lazy;
tr[cur].lazy = 0;
int l = cur<<1, r = cur<<1|1;
tr[l].lazy += lazy, tr[r].lazy += lazy;
tr[l].minval += lazy, tr[r].minval += lazy;
}
void build(int cur, int l, int r) {
if (l == r) {
tr[cur].minval = arr[l].val;
tr[cur].idx = arr[l].idx;
return;
}
int mid = (l+r) >> 1;
build(cur<<1, l, mid);
build(cur<<1|1, mid+1, r);
push_up(cur);
}
void update(int cur, int l, int r, int L, int R, int x) {
if (L > R) return;
if (l >= L && r <= R) {
tr[cur].lazy += x;
tr[cur].minval += x;
return;
}
int mid = (l+r) >> 1;
push_down(cur);
if (L <= mid) update(cur<<1, l, mid, L, R, x);
if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
// query if [1,n] has minval = 0, (query right first), if yes, return the idx (original)
int query(int cur, int l, int r) {
if (tr[cur].minval != 0) return -1;
if (l == r) {
return tr[cur].idx;
}
push_down(cur);
int mid = (l+r) >> 1;
if (tr[cur<<1|1].minval == 0) return query(cur<<1|1, mid+1, r);
return query(cur<<1, l, mid);
}
// delete the element in p
void del(int cur, int l, int r, int p) {
if (l == r) {
tr[cur].minval = 1e9;
return;
}
int mid = (l+r) >> 1;
push_down(cur);
if (p <= mid) del(cur<<1, l, mid, p);
if (p > mid) del(cur<<1|1, mid+1, r, p);
push_up(cur);
}
struct Num {
int val, pre;
} seq[maxn];
int tail = 0;
void init() {
build(1, 1, n);
while (1) {
int p = query(1, 1, n);
if (p == -1) break;
del(1, 1, n, p);
seq[++tail] = {p, (int)1e9};
update(1, 1, n, p+1, n, -1);
}
if (!tail) return;
for (int i = 2; i <= tail; i++) {
seq[i].pre = min(seq[i-1].pre, seq[i-1].val);
}
sort(seq+1, seq+tail+1, [](auto a, auto b) {
return a.val < b.val; // 根据 c 的值先进行sort
});
}
void debug() {
printf("tail = %d\n",tail);
for (int i = 1; i <= tail; i++) {
printf("%d ", seq[i].val);
}
printf("\n");
}
struct persistent_tree_node {
int lc, rc, cnt;
} ptr[maxn<<5];
int root[maxn], id = 0;
int insert(int pre, int l, int r, int p) {
int cur = ++id;
ptr[cur] = ptr[pre];
ptr[cur].cnt++;
if (l == r) return cur;
int mid = (l+r) >> 1;
if (p <= mid) ptr[cur].lc = insert(ptr[pre].lc, l, mid, p);
if (p > mid) ptr[cur].rc = insert(ptr[pre].rc, mid+1, r, p);
return cur;
}
int query(int cur, int pre, int l, int r, int L, int R) {
if (l >= L && r <= R) return ptr[cur].cnt - ptr[pre].cnt;
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += query(ptr[cur].lc, ptr[pre].lc, l, mid, L, R);
if (R > mid) res += query(ptr[cur].rc, ptr[pre].rc, mid+1, r, L, R);
return res;
}
// return the smallest index, where seq[i].val >= x
int search_down(int x) {
int l = 1, r = tail;
int ans = tail+1;
while (l <= r) {
int mid = (l+r) >> 1;
if (seq[mid].val >= x) {
ans = mid;
r = mid-1;
} else l = mid+1;
}
return ans;
}
// return the largest index, where seq[i].val <= x
int search_up(int x) {
int l = 1, r = tail;
int ans = 0;
while (l <= r) {
int mid = (l+r) >> 1;
if (seq[mid].val <= x) {
ans = mid;
l = mid+1;
} else r = mid-1;
}
return ans;
}
int main() {
cin >> n >> q;
for (int i = 1; i <= n; i++) {
cin >> arr[i].val;
arr[i].idx = i;
int val = arr[i].val, idx = i;
if (val > idx) arr[i].val = 1e9;
else arr[i].val = i - val;
}
init();
for (int i = 1; i <= tail; i++) {
root[i] = insert(root[i-1], 1, n, seq[i].pre);
}
while (q--) {
int x,y; cin >> x >> y;
int ans = 0;
int l = search_down(x+1);
int r = search_up(n-y);
if (l > r) ans = 0;
else {
ans = query(root[r], root[l-1], 1, n, x+1, n);
}
// 以下是暴力的做法:
// for (int i = 1; i <= tail; i++) {
// if (seq[i].val >= x+1 && seq[i].val <= n-y && seq[i].pre >= x+1) ans++;
// }
cout << ans << "\n";
}
}
离线做法
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
int n,q;
struct node {
int idx, val;
} arr[maxn];
struct tree_node {
int minval, idx;
int lazy = 0;
} tr[maxn<<2];
void push_up(int cur) {
tr[cur].minval = min(tr[cur<<1].minval, tr[cur<<1|1].minval);
}
void push_down(int cur) {
if (tr[cur].lazy == 0) return;
int lazy = tr[cur].lazy;
tr[cur].lazy = 0;
int l = cur<<1, r = cur<<1|1;
tr[l].lazy += lazy, tr[r].lazy += lazy;
tr[l].minval += lazy, tr[r].minval += lazy;
}
void build(int cur, int l, int r) {
if (l == r) {
tr[cur].minval = arr[l].val;
tr[cur].idx = arr[l].idx;
return;
}
int mid = (l+r) >> 1;
build(cur<<1, l, mid);
build(cur<<1|1, mid+1, r);
push_up(cur);
}
void update(int cur, int l, int r, int L, int R, int x) {
if (L > R) return;
if (l >= L && r <= R) {
tr[cur].lazy += x;
tr[cur].minval += x;
return;
}
int mid = (l+r) >> 1;
push_down(cur);
if (L <= mid) update(cur<<1, l, mid, L, R, x);
if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
// query if [1,n] has minval = 0, (query right first), if yes, return the idx (original)
int query(int cur, int l, int r) {
if (tr[cur].minval != 0) return -1;
if (l == r) {
return tr[cur].idx;
}
push_down(cur);
int mid = (l+r) >> 1;
if (tr[cur<<1|1].minval == 0) return query(cur<<1|1, mid+1, r);
return query(cur<<1, l, mid);
}
// delete the element in p
void del(int cur, int l, int r, int p) {
if (l == r) {
tr[cur].minval = 1e9;
return;
}
int mid = (l+r) >> 1;
push_down(cur);
if (p <= mid) del(cur<<1, l, mid, p);
if (p > mid) del(cur<<1|1, mid+1, r, p);
push_up(cur);
}
struct Num {
int val, pre;
} seq[maxn];
int tail = 0;
void init() {
build(1, 1, n);
while (1) {
int p = query(1, 1, n);
if (p == -1) break;
del(1, 1, n, p);
seq[++tail] = {p, (int)1e9};
update(1, 1, n, p+1, n, -1);
}
if (!tail) return;
for (int i = 2; i <= tail; i++) {
seq[i].pre = min(seq[i-1].pre, seq[i-1].val);
}
sort(seq+1, seq+tail+1, [](auto a, auto b) {
return a.pre > b.pre; // 根据 pre 进行 sort
});
}
struct Query_node {
int x,y,id;
} que[maxn];
int ans[maxn];
struct tree_node2 {
int cnt;
} tr2[maxn<<2];
void insert(int cur, int l, int r, int p) {
tr2[cur].cnt++;
if (l == r) {
return;
}
int mid = (l+r) >> 1;
if (p <= mid) insert(cur<<1, l, mid, p);
else insert(cur<<1|1, mid+1, r, p);
}
int Query(int cur, int l, int r, int L, int R) {
if (L > R) return 0;
if (l >= L && r <= R) return tr2[cur].cnt;
int mid = (l+r) >> 1;
int res = 0;
if (L <= mid) res += Query(cur<<1, l, mid, L, R);
if (R > mid) res += Query(cur<<1|1, mid+1, r, L, R);
return res;
}
int main() {
cin >> n >> q;
for (int i = 1; i <= n; i++) {
cin >> arr[i].val;
arr[i].idx = i;
int val = arr[i].val, idx = i;
if (val > idx) arr[i].val = 1e9;
else arr[i].val = i - val;
}
init();
for (int i = 1; i <= q; i++) {
cin >> que[i].x >> que[i].y;
que[i].id = i;
}
sort(que+1, que+q+1, [](auto a, auto b){
return a.x > b.x;
});
int p = 0;
for (int i = 1; i <= q; i++) {
int x = que[i].x, y = que[i].y, id = que[i].id;
while (p+1 <= tail && seq[p+1].pre >= x+1) { // 满足 pre >= x+1 就插入
p++;
insert(1, 1, n, seq[p].val); // 逐一插入进线段树
}
ans[id] = Query(1, 1, n, x+1, n-y);
}
for (int i = 1; i <= q; i++) cout << ans[i] << "\n";
}