介绍

主席树全名叫做 可持久化权值线段树,一般用于一个数组上,有以下的功能:

  1. 对于每一个区间 都能 开一个权值线段树
  2. 能够维护数组的 历史版本。(仅用于单点插入/修改)

思想

节点的复制

主席树的主要思想在于,对于在线段树上的单点修改,如果要维护多个版本(修改前和修改后),我们可以复制出新的节点,来维护新版本的信息。

由于单点修改仅会影响一条链(从叶子节点,一直到根节点),所以每个版本最多会复制出 $O(\log n)$ 个新节点。

image

如上图,橙色部分就是一个新版本,复制出来了一条链。

• 因为复制节点,所以也需要 动态开点

来一道例题:

例1 洛谷 P3919 【模板】可持久化线段树 1(可持久化数组)

题意

维护一个长度为 $N$ 的数组,共有 $M$ 次询问。询问格式如下:

$v ~ 1 ~ p ~ x$:在版本 $v$ 的基础上,将 $a_p$ 修改为 $x$

$v ~ 2 ~ p$:在版本 $v$ 的基础上,询问 $a_p$ 的值

每次询问后,都生成一个新版本。(所以共有 $M+1$ 个版本)

题解

可持久化的操作在上面说过了。对于每一次修改,都 复制一条链。如果是操作 $2$,复制根节点就可以了。

可持久化的复制节点方式和普通的动态开点略有不同,主要体现在:

  1. 不需要看 cur == 0 与否,直接复制即可,并且将复制后的编号返回。
  2. 需要记录 上一个版本同位置 节点的标号 pre
  3. 需要 build 操作 root[0] = build(1,n),因为初始状态无论是否为空,都需要把线段树开好,否则后面无法复制(也无法进行相减操作)。

访问不同版本的线段树时,就访问它们的 root 即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+5;
const int maxm = 3e7;

int n,m,root[maxn],id, arr[maxn];
struct node {
    int lc,rc,val;
} tr[maxm];

int build(int l, int r) {  // 参数中没有 cur
    int cur = ++id;  // 直接添加
    if (l == r) {
        tr[cur].val = arr[l];
        return cur;
    }
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l, mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

// 将位置p 的值修改为 x
// pre 是前一个版本的 同位置节点
int insert(int pre, int l, int r, int p, int x) {  
    int cur = ++id;
    tr[cur] = tr[pre];
    if (l == r) {
        tr[cur].val = x;
        return cur;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p, x);
    if (p > mid) tr[cur].rc = insert(tr[pre].rc, mid+1, r, p, x);
    return cur;
}

int query(int cur, int l, int r, int p) {
    if (l == r) return tr[cur].val;
    int mid = (l+r) >> 1;
    if (p <= mid) return query(tr[cur].lc, l, mid, p);
    else return query(tr[cur].rc, mid+1, r, p);
}

void init() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    root[0] = build(1,n);  // 注意需要 build(1,n),这是版本 0
}

int main() {
    init();
    for (int i = 1; i <= m; i++) {
        int v, op, p;
        cin >> v >> op >> p;
        if (op == 1) {
            int x; cin >> x;
            root[i] = insert(root[v], 1, n, p, x);
        } else {
            int res = query(root[v], 1, n, p);  // 版本 v 的根
            cout << res << "\n";
            root[i] = root[v];
        }
    }
}

区间询问(两个线段树相减)

区间问题常用的一个思想是 前缀和

那么对于一个数组 $a_1,a_2,…,a_n$,我们可以维护 $n+1$ 个版本的权值线段树。分别维护了 $sum_0, sum_1, sum_2, …, sum_n$ 的信息。

例如,$sum_0$ 是一个空的权值线段树(已经 build() 过的),$sum_3$ 这个权值线段树维护的就是 $a_1,a_2,a_3$ 这个数组的信息。

要求 $sum_3$,我们在 $sum_2$ 的基础上,将 $a_3$ 的信息加进 $sum_2$(单点修改),形成一个新版本的权值线段树即可。

那么,如果我们要求 $[L,R]$ 这个区间对应的权值线段树,只要求出 $sum_r - sum_{l-1}$ 对应的权值线段树就可以了!

线段树之间怎么相减?把对应节点维护的值相减一下即可!

image


以上就是主席树的全部内容了,本质上是 节点复制 + 线段树相减

这样,对于每一个区间,都可以获得一个权值线段树

例题

例2 洛谷P3834 【模板】可持久化线段树 2(主席树)

题意

给定 $N$ 个整数 $a_1,a_2,…,a_n$,和 $m$ 个询问,每次询问 $[L,R]$ 之间的第 $k$ 小值。保证询问合法。

其中,$1 \leq n,m \leq 2 \times 10^5, |a_i| \leq 10^9$

题解

思考一个问题能否用主席树,我们可以先思考,对于整个数组,我们能否用权值线段树解决

答案是可以的!如果我们要求整个数组的第 $k$ 小,可以将所有数字先离散化成排名,然后用权值线段树来维护各个排名的数量。求可以求出第 $k$ 小了!

所以主席树可以解决,步骤如下:

  1. 对整个数组进行离散化
  2. 维护主席树,对于每一个 $[L,R]$ 都可以获得一个权值线段树,然后可以求得第 $k$ 小。

注:离散化的一个很方便的写法是 struct + sort(),然后遍历 sort 后的数组:

int arr[maxn];
int N = 0, rk[maxn], val[maxn];  
// N: 排名数,rk[i]: arr[i]的排名,val[i]: 排名为i的数字的值
void init() {
    for (int i = 1; i <= n; i++) cin >> arr[i].val, arr[i].id = i;
    sort(arr+1, arr+n+1, [](auto a, auto b) {
        return a.val < b.val;
    });

    rk[arr[1].id] = ++N;
    val[1] = arr[1].val;
    for (int i = 2; i <= n; i++) {
        if (arr[i].val > arr[i-1].val) N++;
        rk[arr[i].id] = N;
        val[N] = arr[i].val;
    }
}
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 1e7;

int n,m;
struct Num {
    int val, id;
} arr[maxn];
int rk[maxn], val[maxn];  // rk[1] 代表 arr[1] 离散化后的值(排名), val[1] 代表整个array中第1小的值
int N = 0;

struct node {
    int lc, rc, cnt;
} tr[maxm];
int root[maxn], id = 0;

int build(int l, int r) {
    int cur = ++id;
    if (l == r) {
        return cur;
    }
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l,mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

// pre 是上个版本的节点
// 令 p 位置的 cnt += 1
int insert(int pre, int l, int r, int p) {
    int cur = ++id;
    tr[cur] = tr[pre];  // 复制一份上个版本
    tr[cur].cnt++;  // 添加了一个节点
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid) tr[cur].lc = insert(tr[cur].lc, l, mid, p);
    if (p > mid) tr[cur].rc = insert(tr[cur].rc, mid+1, r, p);
    return cur;
}


// 查询 tr[pre,cur] 之间的第k小,返回具体的值
int query(int pre, int cur, int l, int r, int k) {
    if (l == r) return val[l];
    int prelc = tr[pre].lc, lc = tr[cur].lc;
    int prerc = tr[pre].rc, rc = tr[cur].rc;
    int mid = (l+r) >> 1;

    int lcnt = tr[lc].cnt - tr[prelc].cnt;
    if (lcnt >= k) return query(prelc, lc, l, mid, k);  // 如果左边有 >= k 个数
    else return query(prerc, rc, mid+1, r, k-lcnt);  // 否则只看右边
}

void init() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> arr[i].val, arr[i].id = i;
    sort(arr+1, arr+n+1, [](auto a, auto b) {
        return a.val < b.val;
    });

    rk[arr[1].id] = ++N;
    val[1] = arr[1].val;
    for (int i = 2; i <= n; i++) {
        if (arr[i].val > arr[i-1].val) N++;
        rk[arr[i].id] = N;
        val[N] = arr[i].val;
    }

    root[0] = build(1, N);
    for (int i = 1; i <= n; i++) {
        root[i] = insert(root[i-1], 1, N, rk[i]);
    }
}

int main() {
    init();
    for (int i = 1; i <= m; i++) {
        int l,r,k;
        cin >> l >> r >> k;
        int ans = query(root[l-1], root[r], 1, N, k);
        cout << ans << "\n";
    }
}

例3 CF1422F Boring Queries

题意

给定 $n$ 个正整数 $a_1, a_2, … , a_n$,$q$ 次询问。

每次询问 $[x,y]$,为了保证询问在线,记上一次询问的答案为 $last$,然后令 $L = ((x+last) \text{ mod } n) + 1$,$R = ((y+last) \text{ mod } n) + 1$,如果 $L > R$,则交换 $L,R$。

每次询问,回答 $[L,R]$ 之间所有数的 LCM,答案对 $10^9+7$ 取模。

其中,$1 \leq n,q \leq 10^5, 1 \leq a_i \leq 2 \times 10^5, 1 \leq x,y \leq 10^5$

题解

本题主要有两个难点:

  1. 所有询问在线。
  2. $LCM$ 数字极大,且需要取模,无法正常维护。

因为 $LCM$ 的值极大,且取模,可以考虑 质因数分解

但是,$a_i \leq 2 \times 10^5$,我们无法直接维护每一个质因子。

这时候可以考虑 根号分治,将质因子分为两部分:

一部分是 $\leq \sqrt {(2 \times 10^5)}$ 的质因子。

还有一部分是 $> \sqrt {(2 \times 10^5)}$ 的质因子。


对于 $\leq \sqrt {(2 \times 10^5)}$ 的质因子,我们会发现这种小因子只有 $87$ 个。所以我们只要维护 $87$ 个 ST表 来维护每个小因子在区间内的最大次数即可。


对于 $> \sqrt {(2 \times 10^5)}$ 的质因子,我们会发现它们的出现次数最多为 $1$,并且对于任意一个 $a_i$,它最多只能包含一个这样的大因子。

所以对于大因子而言,求 $LCM$ 就转化为:

求一个区间内,有哪些不同的大因子出现过,将这些 unique 的大因子乘起来就可以了!


那么,如何解决如下的问题?

给定一个数组,询问一个区间,求该区间内 所有不同的数的乘积

这个问题,我们在 HH的项链 中见到过。

但是这个题 强制在线,没法用上面的离线方法来解决。

所以,我们维护一个 pr[] 数组,其中 pr[i] 代表:对于 i,上一个值等于 arr[i] 的 index 的值。

(即:pr[i] = j,其中 arr[j] = arr[i], j < i

在询问 $[L,R]$ 时,我们将问题转化为:

求区间内所有 arr[i] 的乘积,使得:

  1. $i \in [L,R]$

  2. pr[i] < L

这个问题可以用 主席树 解决。我们根据 pr[i] 的值来建主席树,树上节点的值就维护乘积。线段树相减 就用 乘积的逆元 来实现。


最后总结一下本题的步骤:

  1. 欧拉筛求出所有质数。
  2. 维护 $87$ 个小于等于 $450$ 的小质因子,建立 $87$ 个ST表。
  3. 建好大因子的主席树。

本题空间卡的非常紧,我们将 ST表 的数组改为了 short 类型,可以避免 MLE。

代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e5+2;
const int maxm = 2e5+2;

vector<int> primes;
bool is_prime[maxm];
int n, arr[maxn];
void euler() {
    fill(is_prime, is_prime+maxm, 1);
    for (int cur = 2; cur < maxm; cur++) {
        if (is_prime[cur]) primes.push_back(cur);
        for (auto p : primes) {
            if (p * cur >= maxm) break;
            is_prime[p*cur] = 0;
            if (cur % p == 0) break;
        }
    }
}

short st[87][maxn][18];
short bin[maxn];

int ask_st(int i, int l, int r) {
    int len = (r-l+1);
    int k = bin[len];
    return max(st[i][l][k], st[i][r-(1<<k)+1][k]);
}

void build_st() {
    bin[1] = 0;
    bin[2] = 1;
    for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
    for (int i = 0; i <= 86; i++) {
        int p = primes[i];
        for (int j = 1; j <= n; j++) {
            int cnt = 0;
            while (arr[j] % p == 0) {
                arr[j] /= p, cnt++;
            }
            st[i][j][0] = cnt;
        }

        for (int k = 1; k < 18; k++) {
            for (int j = 1; j+(1<<k)-1 <= n; j++) {
                st[i][j][k] = max(st[i][j][k-1], st[i][j+(1<<(k-1))][k-1]);
            }
        }
    }
}

struct node {
    int lc,rc;
    ll val = 1;
} tr[maxn<<5];
int id, root[maxn], pr[maxn];

inline ll qpow(ll a, ll b) {
    if (!b) return 1;
    ll res = 1;
    while (b) {
        if (b&1)
            (res *= a) %= mod;
        (a *= a) %= mod;
        b >>= 1;
    }
    return res;
}
inline ll inv(ll a) {
    return qpow(a, mod-2);
}

int build(int l, int r) {
    int cur = ++id;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l,mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

void push_up(int cur) {
    int lc = tr[cur].lc, rc = tr[cur].rc;
    tr[cur].val = (tr[lc].val * tr[rc].val) % mod;
}

int insert(int pre, int l, int r, int p, ll x) {
    int cur = ++id;
    tr[cur] = tr[pre];
    (tr[cur].val *= x) %= mod;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid)
        tr[cur].lc = insert(tr[pre].lc, l, mid, p, x);
    else
        tr[cur].rc = insert(tr[pre].rc, mid+1, r, p, x);
    push_up(cur);
    return cur;
}

// 询问 [pre, cur] 之间的大质数之积, 且保证 pr[i] < p
ll query(int pre, int cur, int l, int r, int p) {
    if (r < p) return (tr[cur].val * inv(tr[pre].val)) % mod;
    if (l >= p) return 1;
    int mid = (l+r) >> 1;
    ll res = 1;
    (res *= query(tr[pre].lc, tr[cur].lc, l, mid, p)) %= mod;
    (res *= query(tr[pre].rc, tr[cur].rc, mid+1, r, p)) %= mod;
    return res;
}


int pos[maxm];
void init() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    euler();
    build_st();
    root[0] = build(0,n);
    for (int i = 1; i <= n; i++) {
        int val = arr[i];
        pr[i] = pos[val];
        pos[val] = i;
    }
    for (int i = 1; i <= n; i++) {
        root[i] = insert(root[i-1], 0, n, pr[i], arr[i]);
    }
}
ll last = 0;

ll Query(ll L, ll R) {
    L += last, R += last; L %= n, R %= n; L++, R++;
    if (L > R) swap(L,R);
    ll big = query(root[L-1], root[R], 0, n, L);
    ll small = 1;
    for (int i = 0; i <= 86; i++) {
        ll p = primes[i];
        int c = ask_st(i,L,R);
        (small *= qpow(p,c)) %= mod;
    }
    last = (small * big) % mod;
    return last;
}

int main() {
    init();
    int q; cin >> q;
    while (q--) {
        ll L,R; cin >> L >> R;
        cout << Query(L,R) << "\n";
    }
}

例4 洛谷P2468 [SDOI2010]粟粟的书架

题意

给定一个 $R \times C$ 的矩阵,矩阵中所有元素均为正整数,有 $M$ 个询问,每次询问:

$x_1 ~ y_1 ~ x_2 ~ y_2 ~ H$:求 $(x_1,y_1)$ 和 $(x_2,y_2)$ 之间的矩形中,最少取多少个数字可以让数字之和 $\geq H$?

数据范围:

对于 $50$ % 的数据,有 $R,C \leq 200, M \leq 2 \times 10^5$

对于另外 $50$ % 的数据,有 $R = 1, C \leq 5 \times 10^5, M \leq 2 \times 10^4$

矩阵中所有元素满足值在 $[1, 1000]$ 之间,$H \leq 2 \times 10^9$

题解

我们需要让数字之和尽量大,那么每次询问就选择最大的那些数。

可以维护主席树,以元素的值作为权值,节点之中维护 $cnt$ 和 $sum$。

对于 $R = 1$ 的情况很好解决。那么对于 $R \leq 200$ 呢?

有两种方法,

法一:维护 $200$ 棵主席树

询问的过程中,把每一行的线段树都进行 相加 (具体实现通过维护 vector<int> pre, vector<int> cur)。


法二:维护二维前缀和

sum[i][j][k] 代表 $(1,1)$ 和 $(i,j)$ 之间矩阵之中,数值 $\geq k$ 的数字的和

cnt[i][j][k] 代表 $(1,1)$ 和 $(i,j)$ 之间矩阵之中,数值 $\geq k$ 的数字的数量

询问的时候,二分一下 $k$ 就可以了。

以下给出主席树代码。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;

struct node {
    int lc, rc, sum, cnt;
} tr[maxn<<5];

int n,m,Q, id = 0;
vector<int> sum[201];
vector<int> arr[201];
vector<int> root[201];

int build(int l, int r) {
    int cur = ++id;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l, mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

int insert(int pre, int l, int r, int p) {
    int cur = ++id;
    tr[cur] = tr[pre];
    tr[cur].sum += p;
    tr[cur].cnt++;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p);
    else tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
    return cur;
}

pii query(vector<int>& pre, vector<int>& cur, int l, int r, int need) {
    if (need <= 0) return {0,0};
    int allsum = 0, allcnt = 0;
    for (int i = 0; i < pre.size(); i++) {
        int p = pre[i], c = cur[i];
        allsum += tr[c].sum - tr[p].sum;  // 200棵线段树相加
        allcnt += tr[c].cnt - tr[p].cnt;
    }
    if (allsum <= need) return {allcnt, allsum};
    if (l == r) return {(need/l + (need % l > 0)), (need/l + (need % l > 0)) * l};
    int mid = (l+r) >> 1;

    vector<int> plc, lc, prc, rc;
    for (int i = 0; i < pre.size(); i++) {
        int p = pre[i], c = cur[i];
        plc.push_back(tr[p].lc);
        lc.push_back(tr[c].lc);
        prc.push_back(tr[p].rc);
        rc.push_back(tr[c].rc);
    }

    int needcnt = 0, needsum = 0;

    pii res = query(prc, rc, mid+1, r, need);
    need -= res.second;
    needcnt += res.first;
    needsum += res.second;

    res = query(plc, lc, l, mid, need);
    needcnt += res.first;
    needsum += res.second;

    return {needcnt, needsum};
}

bool ok(int x1, int y1, int x2, int y2, int tar) {
    int res = 0;
    for (int i = x1; i <= x2; i++) {
        res += sum[i][y2] - sum[i][y1-1];
    }
    return res >= tar;
}

void init() {
    cin >> n >> m >> Q;
    for (int i = 1; i <= n; i++) {
        sum[i] = arr[i] = root[i] = vector<int>(m+1,0);
        root[i][0] = build(1, 1000);
        for (int j = 1; j <= m; j++) {
            int val; cin >> val;
            arr[i][j] = val;
            sum[i][j] = sum[i][j-1] + val;
            root[i][j] = insert(root[i][j-1], 1, 1000, val);
        }
    }
}

int main() {

    init();
    while (Q--) {
        int x1,y1,x2,y2,tar; cin >> x1 >> y1 >> x2 >> y2 >> tar;
        if (!ok(x1,y1,x2,y2,tar)) cout << "Poor QLW" << "\n";
        else {
            vector<int> cur,pre;
            for (int i = x1; i <= x2; i++) cur.push_back(root[i][y2]), pre.push_back(root[i][y1-1]);
            pii ans = query(pre, cur, 1, 1000, tar);
            cout << ans.first << "\n";
        }
    }
}

例5 洛谷P2633 Count on a tree

题意

给定一棵 $n$ 个节点的树,每个点 $i$ 具有权值 $a_i$。

有 $m$ 个询问,每次询问 $u ~ v ~ k$,回答 $u \text { xor } last$ 和 $v$ 的最短路径中,第 $k$ 小的点权。

其中,$last$ 为上次询问的答案,且保证每次询问均合法。

点权值的范围在 $[0, 2^{31}-1]$ 之间。

题解

树上主席树

与普通主席树不同,我们主席树中的 前缀 代表了从 parent 继承而来的部分。即,我们选定 $1$ 作为根,那么一条从上到下的路径,就形成了一个主席树上的 前缀

insert() 的时候,就有如下的代码:

int root_id = 0, root[maxn], ver_root[maxn];   // ver_root[u] 的值为 u 对应的版本的 root index
void dfs(int u, int p) {
    root_id++;
    ver_root[u] = root_id;
    root[root_id] = insert(root[ver_root[p]], 1, N, rk[u]);  // 从parent p那里继承而来
    // ....
}

那么,怎么将 $u,v$ 之间的路径转化为 前缀之间的加减?(也就是线段树之间的加减)

回忆一下 树上启发式合并中,一道关于形成回文串路径的题目:CF741D

我们当时采用的是:

$$f_{u,v} = (f_u \text{ xor } f_x) \text{ xor } (f_v \text{ xor } f_x) = f_u \text{ xor } f_v$$

其中 $x = LCA(u,v)$


那么本题的思路也一样,将路径问题转化为 $LCA$ 问题。(实际上,本质就是 树上差分

所以,令 $x = LCA(u,v)$,令 $f_u$ 为 $1 \rightarrow u$ 的路径对应的线段树。

我们发现在线段树上,$u,v$ 之间的路径 $f_{u,v}$,就是 $$f_u + f_v - f_x - f_{par(x)}$$

那么剩下的就是经典的主席树模版,区间第 $k$ 小问题了。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;

struct Num {
    int id, v;
} nums[maxn];

struct node {
    int lc, rc, cnt;
} tr[maxn<<5];

int n, m, arr[maxn], rk[maxn], val[maxn], N = 0, id = 0;
int root[maxn];
int ecnt = 1, head[maxn], par[maxn][19], dep[maxn];

struct Edge {
    int to, nxt;
} edges[maxn<<1];

void addEdge(int u, int v) {
    Edge e = {v, head[u]};
    edges[ecnt] = e;
    head[u] = ecnt++;
}

int build(int l, int r) {
    int cur = ++id;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l, mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

int insert(int pre, int l, int r, int p) {
    int cur = ++id;
    tr[cur] = tr[pre];
    tr[cur].cnt++;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid)
        tr[cur].lc = insert(tr[pre].lc, l, mid, p);
    else
        tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
    return cur;
}

int query(int pre1, int pre2, int cur1, int cur2, int l, int r, int k) {
    if (l == r) return l;
    int plc1 = tr[pre1].lc, prc1 = tr[pre1].rc;
    int plc2 = tr[pre2].lc, prc2 = tr[pre2].rc;
    int lc1 = tr[cur1].lc, rc1 = tr[cur1].rc;
    int lc2 = tr[cur2].lc, rc2 = tr[cur2].rc;

    int mid = (l+r) >> 1;

    int lcnt = tr[lc1].cnt + tr[lc2].cnt - tr[plc1].cnt - tr[plc2].cnt;  // 线段树加减

    if (k <= lcnt) return query(plc1, plc2, lc1, lc2, l, mid, k);
    else return query(prc1, prc2, rc1, rc2, mid+1, r, k-lcnt);
}

int jump(int u, int d) {
    int c = 0;
    while (d) {
        if (d&1) u = par[u][c];
        d >>= 1, c++;
    }
    return u;
}

int LCA(int u, int v) {
    if (dep[u] < dep[v]) swap(u,v);
    int d = dep[u] - dep[v];
    u = jump(u, d);
    if (u == v) return u;
    for (int j = 18; j >= 0; j--) {
        if (par[u][j] != par[v][j]) u = par[u][j], v = par[v][j];
    }
    return par[u][0];
}

int root_id = 0, ver_root[maxn];   // ver_root[u] 的值为 u 对应的版本的 root index
void dfs(int u, int p) {
    root_id++;
    ver_root[u] = root_id;
    root[root_id] = insert(root[ver_root[p]], 1, N, rk[u]);

    dep[u] = dep[p] + 1;
    par[u][0] = p;
    for (int j = 1; j <= 18; j++) par[u][j] = par[par[u][j-1]][j-1];

    for (int e = head[u]; e; e = edges[e].nxt) {
        int to = edges[e].to;
        if (to == p) continue;
        dfs(to, u);
    }
}

int Query(int u, int v, int k) {
    int lca = LCA(u,v);
    int res = query(root[ver_root[lca]], root[ver_root[par[lca][0]]], root[ver_root[u]], root[ver_root[v]], 1, N, k);
    return val[res];
}

void init() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> arr[i], nums[i] = {i, arr[i]};
    for (int i = 1; i <= n-1; i++) {
        int u,v; cin >> u >> v;
        addEdge(u, v); addEdge(v, u);
    }

    // 离散化
    sort(nums + 1, nums + 1 + n, [](auto a, auto b) {
        return a.v < b.v;
    });
    rk[nums[1].id] = ++N; val[1] = nums[1].v;
    for (int i = 2; i <= n; i++) {
        if (nums[i].v > nums[i-1].v) {
            N++;
        }
        rk[nums[i].id] = N;
        val[N] = nums[i].v;
    }
    root[0] = build(1, N);
    dfs(1, 0);
}

int last = 0;
int main() {
    fastio;
    init();

    while (m--) {
        int u,v,k; cin >> u >> v >> k;
        u ^= last;
        last = Query(u,v,k);
        cout << last << "\n";
    }
}

例6 CF813E Army Creation

题意

给定 $n$ 个数字 $a_1,a_2,…,a_n$,给定一个数字 $k$,有 $q$ 次询问,每次询问:

$x ~ y$:令 $L = ((x+last) \text { mod } n) + 1, R = ((y+last) \text { mod } n) + 1$,回答 $[L,R]$ 中,最多可以选多少个数,使得任何一个数字选择的数字次数 $\leq k$?

题解

仍然是 强制在线,还是和例4的做法一样,维护 pr[] 数组:只不过维护的值改了一下:

$pr[i] = j$,满足 $arr[i] = arr[j]$,且 $i$ 往前走 $k$ 个位置,就得到 $j$ 。

然后每次询问就回答 $[L,R]$ 之中,有多少个 $i \in [L,R]$ 满足 $pr[i] < L$。

预处理这个 pr[] 数组,可以通过倍增的方式(类似于 LCA)来做。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;

int n,k;
int par[maxn][19], pr[maxn], id = 0, root[maxn], last = 0, arr[maxn], pos[maxn];

struct node {
    int lc,rc,cnt;
} tr[maxn<<5];

int build(int l, int r) {
    int cur = ++id;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    tr[cur].lc = build(l, mid);
    tr[cur].rc = build(mid+1, r);
    return cur;
}

int insert(int pre, int l, int r, int p) {
    int cur = ++id;
    tr[cur] = tr[pre];
    tr[cur].cnt++;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid) tr[cur].lc = insert(tr[pre].lc, l, mid, p);
    else tr[cur].rc = insert(tr[pre].rc, mid+1, r, p);
    return cur;
}

int query(int pre, int cur, int l, int r, int L) {
    if (r < L) return tr[cur].cnt - tr[pre].cnt;
    if (l >= L) return 0;
    int mid = (l+r) >> 1;
    int res = 0;
    res += query(tr[pre].lc, tr[cur].lc, l, mid, L);
    res += query(tr[pre].rc, tr[cur].rc, mid+1, r, L);
    return res;
}

int jump(int u, int d) {
    int c = 0;
    while (d) {
        if (d&1) u = par[u][c];
        c++; d >>= 1;
    }
    return u;
}

void init() {
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        int val; cin >> val;
        arr[i] = val;
        par[i][0] = pos[val];
        pos[val] = i;
    }
    root[0] = build(0, n);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= 18; j++) {
            par[i][j] = par[par[i][j-1]][j-1];
        }
        pr[i] = jump(i, k);
        root[i] = insert(root[i-1], 0, n, pr[i]);
    }
}

int main() {

    init();
    int Q; cin >> Q;
    while (Q--) {
        int l,r; cin >> l >> r;
        l = (l + last) % n + 1;
        r = (r + last) % n + 1;
        if (l > r) swap(l,r);
        last = query(root[l-1], root[r], 0, n, l);
        cout << last << "\n";
    }
}

例7 CF1404C Fixed Point Removal

题意

给定长度为 $n$ 的数组 $a_1,a_2,…,a_n$。

我们每次可以选择 $i$ 使得 $a_i = i$,然后将它删掉。删掉后,剩下的部分会合并在一起。

现在有 $q$ 个询问,每次询问 $(x,y)$,回答:

如果将 $a_1,a_2,…,a_x$ 赋值为 $n+1$,$a_n,a_{n-1},…,a_{n-y+1}$ 赋值为 $n+1$ (实际上,就是让 $a$ 的前 $x$ 个元素 和 后 $y$ 个元素变得无法删除。),我们能最多删除多少个元素?

询问之间完全独立,互不影响。

其中,$1 \leq n,q \leq 3 \times 10^5, a_i \in [1,n], x,y \geq 0, x+y < n$

题解

首先,我们先不考虑询问的问题。

对于原数组,如何求出最多删除的元素个数?

我们会发现,当我们删除一个元素时,它右边的所有元素的 index 都会减少 $1$,这可能导致右边的元素又出现了可删除的。

同时我们发现,如果对于 $i$,有 $a_i > i$,则无论怎么删除,$i$ 只减不增,这样 $a_i = i$ 永远不可能成立。

这意味着,当出现 $a_i = i$ 时,就要立刻删除,如果它左边的元素被删了,它就再也删不掉了。

所以对于原数组,最优的删除方法是:

从右侧开始删,删完一个数字后,判断右边有没有出现新的可删数字,如果有的话继续从右开始。


有了贪心的删除策略,我们不能简单的模拟删除过程,因为这是 $O(n^2)$ 的。

每次删除操作,我们需要对右边的元素的index 进行区间减 $1$,很明显可以用线段树来维护。

不妨将问题转化一下,令 $b_i = i - a_i$:

  1. 当 $b_i = 0$ 时,可以删除,删除时,将 $[i+1,n]$ 的所有 $b_i$ 都减去 $1$。
  2. 当 $b_i < 0$ 时,永远都不可能删除。
  3. 当 $b_i > 0$ 时,如果它未来的某时刻被减为 $0$ 了,则可以被删掉。

那我们在线段树里面维护一个 $min$,代表区间最小值。

如果一个元素 $< 0$,则直接把它赋值为 $10^9$,无论怎么减,它都不可能等于 $0$,代表着它是一个无效元素(永远不可能被删除,或者已经被删除)。

每次询问一下区间最小值的位置,如果最小值为 $0$,就把这个位置 $i$ 的元素删掉,然后将 $[i+1,n]$ 都减去 $1$,然后将 $b_i$ 设为 $10^9$(代表已删除)。


现在有了 无询问 情况下的解,有询问怎么办?

我们观察一下样例数组(长度为 $13$)

$[2,2,3 ,9 ,5 ,4 ,6 ,5 ,7, 8, 3, 11, 13]$

我们删除的 index 顺序(指原数组的index)是:$13,5,12,7,10,9,3,8,6,2,11$。

假设我们询问了 $x=3, y=1$,这意味着 $1,2,3$ 和 $13$ 都不能出现在这个删除序列中了(我们给它们 打上标记),并且它后面的 所有 比它大 的数字也被打上标记了(因为它实际上依赖前面的数字,但是前面的这些数字无法被删了)!如下:

找 $1$,$1$ 不在删除序列中,忽略。

找 $2$,$2$ 在删除序列中,后面比它大的数字有 $11$,所以 $2,11$ 打上标记。

找 $3$,$3$ 在删除序列中,后面比它大的数字有 $8,6,2,11$,所以 $3,8,6,2,11$ 打上标记。

找 $13$,$13$ 在删除序列中,后面不存在比它大的数字。所以 $13$ 打上标记。

以上,剩下没有被标记的,只有 $5,12,7,10,9$,共 $5$ 个数字,所以答案为 $5$。


将上述的模拟过程,总结一下就是:

设删除序列(就是上面的 $13,5,12,7,10,9,3,8,6,2,11$)为 $c$。

那么,每次询问 $(x,y)$,所有 未被标记 的index $i$,必须得满足以下的所有条件:

  1. $c_i \in [x+1, n-y]$
  2. $pre_i \geq x+1$

其中,$pre_i$ 代表 $\min \{ c_1,c_2,…,c_{i-1}\}$。这是因为,如果 $pre_i \leq x$,则 $i$ 前面必然存在一个数字被标记了,所以 $i$ 也要被标记。


现在问题就转化为:

给定一个长度为 $n$ 的数组,每个元素是 $(c, pre)$ 的形式。

每次询问 $(x,y)$,求有多少个 $i \in [1,n]$ 满足以下所有条件:

  1. $c_i \in [x+1, n-y]$
  2. $pre_i \geq x+1$

有离线和在线两种方法。


在线做法:

将数组根据 $c$ 的值,sort一下。

每次询问的时候,先用二分找到左边界 $l$,满足 $c_l \geq x+1$,且 $l$ 尽可能小。

再二分找到右边界 $r$,满足 $c_r \leq n-y$,且 $r$ 尽可能大。

然后问题就转化为,求 $[l,r]$ 之间,有多少个元素 $i$,满足:

  1. $i \in [l,r]$
  2. $pre_i \geq x+1$

那这就是一个标准的主席树问题了。

复杂度:$O(n\log^2n)$


离线做法:

将所有的询问,根据 $x$ 的值,从大到小 进行sort。

然后,将数组根据 $pre$ 的值,从大到小 进行sort。

所以,我们在回答每个询问 $(x,y)$ 的时候,我们只需要考虑 $pre_i \geq x+1$ 的部分。

也就是说我们可以开一个线段树,维护 $c_i$ 的值。

假如我们当前处理到了 询问 $(x,y)$,我们只在线段树内维护所有满足 $pre_i \geq x+1$ 的 $c_i$ 即可。

• 这本质上是一个,将数组内的元素,根据 $pre_i$ 的值,逐一插入到线段树中的过程。

复杂度:$O(n\log n)$


总结:

本题的思考流程分为以下几个步骤:

  1. 找到最优的删除策略。
  2. 对于无询问状态下,如何模拟删除过程。
  3. 对于有询问状态下,如何模拟删除过程。
  4. 如何处理一个 类似二维数点 的问题。
在线做法(主席树)
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
 
int n,q;
struct node {
    int idx, val;
} arr[maxn];
 
struct tree_node {
    int minval, idx;
    int lazy = 0;
} tr[maxn<<2];
 
void push_up(int cur) {
    tr[cur].minval = min(tr[cur<<1].minval, tr[cur<<1|1].minval);
}
 
void push_down(int cur) {
    if (tr[cur].lazy == 0) return;
    int lazy = tr[cur].lazy;
    tr[cur].lazy = 0;
    int l = cur<<1, r = cur<<1|1;
    tr[l].lazy += lazy, tr[r].lazy += lazy;
    tr[l].minval += lazy, tr[r].minval += lazy;
}
 
void build(int cur, int l, int r) {
    if (l == r) {
        tr[cur].minval = arr[l].val;
        tr[cur].idx = arr[l].idx;
        return;
    }
    int mid = (l+r) >> 1;
    build(cur<<1, l, mid);
    build(cur<<1|1, mid+1, r);
    push_up(cur);
}
 
void update(int cur, int l, int r, int L, int R, int x) {
    if (L > R) return;
    if (l >= L && r <= R) {
        tr[cur].lazy += x;
        tr[cur].minval += x;
        return;
    }
    int mid = (l+r) >> 1;
    push_down(cur);
    if (L <= mid) update(cur<<1, l, mid, L, R, x);
    if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
    push_up(cur);
}
 
// query if [1,n] has minval = 0, (query right first), if yes, return the idx (original)
int query(int cur, int l, int r) {
    if (tr[cur].minval != 0) return -1;
    if (l == r) {
        return tr[cur].idx;
    }
    push_down(cur);
    int mid = (l+r) >> 1;
    if (tr[cur<<1|1].minval == 0) return query(cur<<1|1, mid+1, r);
    return query(cur<<1, l, mid);
}
 
// delete the element in p
void del(int cur, int l, int r, int p) {
    if (l == r) {
        tr[cur].minval = 1e9;
        return;
    }
    int mid = (l+r) >> 1;
    push_down(cur);
    if (p <= mid) del(cur<<1, l, mid, p);
    if (p > mid) del(cur<<1|1, mid+1, r, p);
    push_up(cur);
}
 
struct Num {
    int val, pre;
} seq[maxn];
int tail = 0;
 
void init() {
    build(1, 1, n);
    while (1) {
        int p = query(1, 1, n);
        if (p == -1) break;
        del(1, 1, n, p);
        seq[++tail] = {p, (int)1e9};
        update(1, 1, n, p+1, n, -1);
    }
    if (!tail) return;
    for (int i = 2; i <= tail; i++) {
        seq[i].pre = min(seq[i-1].pre, seq[i-1].val);
    }
    sort(seq+1, seq+tail+1, [](auto a, auto b) {
        return a.val < b.val;  // 根据 c 的值先进行sort
    });
}
 
void debug() {
    printf("tail = %d\n",tail);
    for (int i = 1; i <= tail; i++) {
        printf("%d ", seq[i].val);
    }
    printf("\n");
}
 
struct persistent_tree_node {
    int lc, rc, cnt;
} ptr[maxn<<5];
int root[maxn], id = 0;
 
int insert(int pre, int l, int r, int p) {
    int cur = ++id;
    ptr[cur] = ptr[pre];
    ptr[cur].cnt++;
    if (l == r) return cur;
    int mid = (l+r) >> 1;
    if (p <= mid) ptr[cur].lc = insert(ptr[pre].lc, l, mid, p);
    if (p > mid) ptr[cur].rc = insert(ptr[pre].rc, mid+1, r, p);
    return cur;
}
 
int query(int cur, int pre, int l, int r, int L, int R) {
    if (l >= L && r <= R) return ptr[cur].cnt - ptr[pre].cnt;
    int mid = (l+r) >> 1;
    int res = 0;
    if (L <= mid) res += query(ptr[cur].lc, ptr[pre].lc, l, mid, L, R);
    if (R > mid) res += query(ptr[cur].rc, ptr[pre].rc, mid+1, r, L, R);
    return res;
}
 
// return the smallest index, where seq[i].val >= x
int search_down(int x) {
    int l = 1, r = tail;
    int ans = tail+1;
    while (l <= r) {
        int mid = (l+r) >> 1;
        if (seq[mid].val >= x) {
            ans = mid;
            r = mid-1;
        } else l = mid+1;
    }
    return ans;
}
 
// return the largest index, where seq[i].val <= x
int search_up(int x) {
    int l = 1, r = tail;
    int ans = 0;
    while (l <= r) {
        int mid = (l+r) >> 1;
        if (seq[mid].val <= x) {
            ans = mid;
            l = mid+1;
        } else r = mid-1;
    }
    return ans;
}
 
int main() {
    cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        cin >> arr[i].val;
        arr[i].idx = i;
        int val = arr[i].val, idx = i;
        if (val > idx) arr[i].val = 1e9;
        else arr[i].val = i - val;
    }
    init();
    for (int i = 1; i <= tail; i++) {
        root[i] = insert(root[i-1], 1, n, seq[i].pre);
    }
 
    while (q--) {
        int x,y; cin >> x >> y;
        int ans = 0;
 
        int l = search_down(x+1);
        int r = search_up(n-y);
        if (l > r) ans = 0;
        else {
            ans = query(root[r], root[l-1], 1, n, x+1, n);
        }
 
        // 以下是暴力的做法:
        // for (int i = 1; i <= tail; i++) {
        //     if (seq[i].val >= x+1 && seq[i].val <= n-y && seq[i].pre >= x+1) ans++;
        // }

        cout << ans << "\n";
    }
}
离线做法
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
 
int n,q;
struct node {
    int idx, val;
} arr[maxn];
 
struct tree_node {
    int minval, idx;
    int lazy = 0;
} tr[maxn<<2];
 
void push_up(int cur) {
    tr[cur].minval = min(tr[cur<<1].minval, tr[cur<<1|1].minval);
}
 
void push_down(int cur) {
    if (tr[cur].lazy == 0) return;
    int lazy = tr[cur].lazy;
    tr[cur].lazy = 0;
    int l = cur<<1, r = cur<<1|1;
    tr[l].lazy += lazy, tr[r].lazy += lazy;
    tr[l].minval += lazy, tr[r].minval += lazy;
}
 
void build(int cur, int l, int r) {
    if (l == r) {
        tr[cur].minval = arr[l].val;
        tr[cur].idx = arr[l].idx;
        return;
    }
    int mid = (l+r) >> 1;
    build(cur<<1, l, mid);
    build(cur<<1|1, mid+1, r);
    push_up(cur);
}
 
void update(int cur, int l, int r, int L, int R, int x) {
    if (L > R) return;
    if (l >= L && r <= R) {
        tr[cur].lazy += x;
        tr[cur].minval += x;
        return;
    }
    int mid = (l+r) >> 1;
    push_down(cur);
    if (L <= mid) update(cur<<1, l, mid, L, R, x);
    if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
    push_up(cur);
}
 
// query if [1,n] has minval = 0, (query right first), if yes, return the idx (original)
int query(int cur, int l, int r) {
    if (tr[cur].minval != 0) return -1;
    if (l == r) {
        return tr[cur].idx;
    }
    push_down(cur);
    int mid = (l+r) >> 1;
    if (tr[cur<<1|1].minval == 0) return query(cur<<1|1, mid+1, r);
    return query(cur<<1, l, mid);
}
 
// delete the element in p
void del(int cur, int l, int r, int p) {
    if (l == r) {
        tr[cur].minval = 1e9;
        return;
    }
    int mid = (l+r) >> 1;
    push_down(cur);
    if (p <= mid) del(cur<<1, l, mid, p);
    if (p > mid) del(cur<<1|1, mid+1, r, p);
    push_up(cur);
}
 
struct Num {
    int val, pre;
} seq[maxn];
int tail = 0;
 
void init() {
    build(1, 1, n);
    while (1) {
        int p = query(1, 1, n);
        if (p == -1) break;
        del(1, 1, n, p);
        seq[++tail] = {p, (int)1e9};
        update(1, 1, n, p+1, n, -1);
    }
    if (!tail) return;
    for (int i = 2; i <= tail; i++) {
        seq[i].pre = min(seq[i-1].pre, seq[i-1].val);
    }
    sort(seq+1, seq+tail+1, [](auto a, auto b) {
        return a.pre > b.pre;  // 根据 pre 进行 sort
    });
}
 
struct Query_node {
    int x,y,id;
} que[maxn];
int ans[maxn];
 
struct tree_node2 {
    int cnt;
} tr2[maxn<<2];
 
void insert(int cur, int l, int r, int p) {
    tr2[cur].cnt++;
    if (l == r) {
        return;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) insert(cur<<1, l, mid, p);
    else insert(cur<<1|1, mid+1, r, p);
}
 
int Query(int cur, int l, int r, int L, int R) {
    if (L > R) return 0;
    if (l >= L && r <= R) return tr2[cur].cnt;
    int mid = (l+r) >> 1;
    int res = 0;
    if (L <= mid) res += Query(cur<<1, l, mid, L, R);
    if (R > mid) res += Query(cur<<1|1, mid+1, r, L, R);
    return res;
}
 
int main() {
    cin >> n >> q;
    for (int i = 1; i <= n; i++) {
        cin >> arr[i].val;
        arr[i].idx = i;
        int val = arr[i].val, idx = i;
        if (val > idx) arr[i].val = 1e9;
        else arr[i].val = i - val;
    }
    init();
 
    for (int i = 1; i <= q; i++) {
        cin >> que[i].x >> que[i].y;
        que[i].id = i;
    }
    sort(que+1, que+q+1, [](auto a, auto b){
        return a.x > b.x;
    });
 
    int p = 0;
    for (int i = 1; i <= q; i++) {        
        int x = que[i].x, y = que[i].y, id = que[i].id;
        while (p+1 <= tail && seq[p+1].pre >= x+1) {  // 满足 pre >= x+1 就插入
            p++;
            insert(1, 1, n, seq[p].val);  // 逐一插入进线段树
        }
        ans[id] = Query(1, 1, n, x+1, n-y);
    }
 
    for (int i = 1; i <= q; i++) cout << ans[i] << "\n";
}

参考链接

  1. https://www.luogu.com.cn/blog/Fighting-Naruto/solution-p3834