主要记录一些遇到的线段树/分块例题。

例1 CF438D

题意

给定 $N$ 个正整数和 $M$ 个询问,询问有 3 种:

$1 ~ l ~ r$:输出 $\sum\limits_{i=l}^r a_i$

$2 ~ l ~ r ~ x$:将 $a_l$ 到 $a_r$ 的所有数取 $\text{mod } x$

$3 ~ k ~ x$:将 $a_k = x$

其中,$1 \leq N,M \leq 10^5, 1 \leq a_i,x \leq 10^9$

题解

本题分块和线段树都可以做,我们这里用 线段树 来做。

主要是需要考虑 区间取模 怎么办?

回忆一下分块例题中的 区间开方,我们维护了一个额外的tag表示这个区间是否为 全0/全1,如果不是 全0/全1 就暴力开方。

取模操作同理,我们发现,如果 $a_i > x$,那么 $a_i \text{ mod } x \leq \frac{a_i}{2}$,所以对于每个 $a_i$,最多只会被 $\text{mod}$ $\log (a_i)$ 次!

所以,我们维护一个 区间最大值,取模时,检查一下 区间最大值是否大于 $x$

  1. 如果大于 $x$,就继续往下递归。
  2. 如果小于 $x$,就直接返回。

base case 就是区间长度为 $1$ 时,直接对这个元素开方即可。

例2 CF558E

题意

给定一个长度为 $n$ 的string,仅包含小写字母。给 $q$ 个询问:

$l,r,k$:将string的 $[l,r]$ 进行sort,如果 $k=1$ 就升序,$k=0$ 降序。

输出所有询问结束后的string。

其中,$1\leq n \leq 10^5, 1 \leq q\leq 50000$

题解

线段树来处理。

首先,string只包含小写字母。所以每个 node 可以维护一个 cnt[26] 代表这个node里的每个字母出现的次数。

其次,对于排序,我们在每个 node 中维护一个标记 $k$ 来代表该区间是否排序好了。若 $k=0$ 代表降序,$k=1$ 代表升序,$k = -1$ 代表乱序。

最后,维护一个 $lazy$ 标记,我们会注意到对于一个node而言,若 $lazy = 1$,那么这个 node 必然是排序好了的!(要么 $k=0$ ,要么 $k=1$)。

有了以上信息,我们就可以进行 sort 操作了!


sort $[L,R]$ 的时候,步骤如下:

  1. 提取出 $[L,R]$ 内每个字母出现的次数。
  2. 求出 $[L,R]$ 与 $[l,mid]$(当前node 左child的范围)的区间交集 $[l_1,r_1]$
  3. 求出 $[L,R]$ 与 $[mid+1, r]$(当前node 右child的范围)的区间交集 $[l_2,r_2]$
  4. 用指针遍历 $a-z$(或者 $z-a$),根据升序/降序 将 字母出现的次数分别填充左child和右childcnt[] 中。(注意,这里的填充是指:先填充进一个 int* buf = new int[26]; 的动态数组,然后buf[] 作为参数,再往下传递,直到区间完全覆盖,再将 buf[] 的内容复制进 cnt[] 里)。
代码
using namespace std;
#include <bits/stdc++.h>
const int maxn = 1e5+5;
const int maxm = 2e5+10;
 
int n,q;
char arr[maxn];
 
struct Node {
    int l,r,k,cnt[26];
    bool lazy = 0;
} tr[4*maxn];
int tmp[26];
 
inline int len(int cur) { return tr[cur].r - tr[cur].l + 1; }
 
void push_up(int cur) {
    int lc = cur<<1, rc = lc+1;
    for (int i = 0; i < 26; i++) {
        tr[cur].cnt[i] = tr[lc].cnt[i] + tr[rc].cnt[i];
    }
    if (tr[lc].k != -1 && tr[lc].k == tr[rc].k) tr[cur].k = tr[lc].k;  // k = 1: increasing, k = 0: decreasing
    else tr[cur].k = -1;
}
 
void put(int cur, int k) {
    int lc = cur<<1, rc = lc+1;
    memset(tr[lc].cnt, 0, sizeof(tr[lc].cnt));
    memset(tr[rc].cnt, 0, sizeof(tr[rc].cnt));
    memcpy(tmp, tr[cur].cnt, sizeof(tmp));
 
    int lsz = len(lc), rsz = len(rc);
    if (k) {
        int p = 0;
        while (p < 26 && lsz) {
            int delta = min(lsz, tmp[p]);
            tr[lc].cnt[p] += delta;
            lsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p++;
        }
        while (p < 26 && rsz) {
            int delta = min(rsz, tmp[p]);
            tr[rc].cnt[p] += delta;
            rsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p++;
        }
    } else {
        int p = 25;
        while (p >= 0 && lsz) {
            int delta = min(lsz, tmp[p]);
            tr[lc].cnt[p] += delta;
            lsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p--;
        }
        while (p >= 0 && rsz) {
            int delta = min(rsz, tmp[p]);
            tr[rc].cnt[p] += delta;
            rsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p--;
        }
    }
    assert(lsz == 0);
    assert(rsz == 0);
}
 
void push_down(int cur) {
    if (!tr[cur].lazy) return;
    int lc = cur<<1, rc = lc+1;
    tr[cur].lazy = 0;
    tr[lc].lazy = tr[rc].lazy = 1;
    assert(tr[cur].k != -1);
 
    int k = tr[cur].k;
    tr[lc].k = k;
    tr[rc].k = k;
    put(cur,k);
}
 
void build(int cur, int L, int R) {
    tr[cur].l = L, tr[cur].r = R;
    if (L == R) {
        memset(tr[cur].cnt, 0, sizeof(tr[cur].cnt));
        tr[cur].cnt[arr[L]-'a'] = 1;
        return;
    }
    int mid = (L+R) >> 1;
    build(cur<<1, L, mid);
    build(cur<<1|1, mid+1, R);
    push_up(cur);
}
 
int ress[26];  // 每次query的结果会存到这里
 
void clear(int* buf) {
    for (int i = 0; i < 26; i++) buf[i] = 0;
}
 
int inter(int l1, int r1, int l2, int r2) {  // 求区间交集的长度
    int l = max(l1,l2), r = min(r1,r2);
    return max(0,r-l+1);
}
 
void update(int cur, int L, int R, int k, int* res) {  // 注意参数里有个动态数组 res
    int lc = cur<<1, rc = lc+1;
    int l = tr[cur].l, r = tr[cur].r;
    if (l >= L && r <= R) {
        tr[cur].k = k;
        tr[cur].lazy = 1;
        for (int i = 0; i < 26; i++) tr[cur].cnt[i] = res[i];  // 区间完全覆盖,复制到 cnt 中。
        clear(res);  // 记得清空,之后可能还要用
        return;
    }
    int mid = (l+r) >> 1;
    int lsz = inter(l,mid,L,R), rsz = inter(mid+1,r,L,R);
 
    int* buf = new int[26];  // 这里采用了动态数组
    for (int i = 0; i < 26; i++) buf[i] = 0;  //注意new出来的需要先清空一下,另外不能使用 memset(因为是指针)
 
    if (k) {
        int p = 0;
        while (p < 26 && lsz) {
            int delta = min(lsz, res[p]);
            buf[p] += delta;
            lsz -= delta;
            res[p] -= delta;
            if (!res[p]) p++;
        }
        if (L <= mid) {
            update(lc, L, R, k, buf);  // 传递 buf
        }
 
        while (p < 26 && rsz) {
            int delta = min(rsz, res[p]);
            buf[p] += delta;
            rsz -= delta;
            res[p] -= delta;
            if (!res[p]) p++;
        }
 
        if (R > mid) {
            update(rc, L, R, k, buf);  // 传递 buf
        }
 
    } else {
        int p = 25;
        while (p >= 0 && lsz) {
            int delta = min(lsz, res[p]);
            buf[p] += delta;
            lsz -= delta;
            res[p] -= delta;
            if (!res[p]) p--;
        }
        if (L <= mid) {
            update(lc, L, R, k, buf);  // 传递 buf
        }
        while (p >= 0 && rsz) {
            int delta = min(rsz, res[p]);
            buf[p] += delta;
            rsz -= delta;
            res[p] -= delta;
            if (!res[p]) p--;
        }
        if (R > mid) {
            update(rc, L, R, k, buf);  // 传递 buf
        }
    }
    delete[] buf;
 
    assert(lsz == 0);
    assert(rsz == 0);
    push_up(cur);
}
 
void query(int cur, int L, int R) {  // 求区间内每个字母出现的个数
    int l = tr[cur].l, r = tr[cur].r;
    if (l >= L && r <= R) {
        for (int i = 0; i < 26; i++) ress[i] += tr[cur].cnt[i];
        return;
    }
 
    int lc = cur<<1, rc = lc+1;
    push_down(cur);
    int mid = (l+r) >> 1;
    if (L <= mid) query(lc, L, R);
    if (R > mid) query(rc, L, R);
    push_up(cur);
}
 
void printans() {
    for (int i = 1; i <= n; i++) {
        memset(ress, 0, sizeof(ress));
        query(1,i,i);
        for (int j = 0; j < 26; j++) {
            if (ress[j]) {
                printf("%c",(char)('a'+j));
                ress[j]--;
                break;
            }
        }
    }
    printf("\n");
}
 
int main() {
 
    scanf("%d%d",&n,&q);
    scanf("%s", arr+1);
    build(1, 1, n);
    while (q--) {
        int l,r,k; scanf("%d%d%d",&l,&r,&k);
        memset(ress, 0, sizeof(ress));
        query(1,l,r);
        update(1,l,r,k,ress);
    }
    printans();
}

例3 洛谷P1972 HH的项链

题意

给定一个长度为 $N$ 的数组 $a_1,a_2,…,a_n$,以及 $m$ 个询问,每次询问 $[L,R]$ 之间有多少个不同的数。

其中,$1 \leq n,m,a_i \leq 10^6$

题解

我们可以发现,如果我们固定了询问的右端点 $R$,那么无论 $L$ 为多少,在 $R$ 的左侧的所有重复数字中,仅保留最靠右的一个 copy 即可。

例如 $arr = 1,3,2,1,7,1$,那么我们在遍历到 $i = 4$ 时,我们仅需要保留最后一个 $1$ (也就是 index 为 $4$ 的数字)。

由上,在处理 区间内不同的数 时,一个常见的套路是:

  1. 离线处理所有询问,按右端点 $R$ 排序。
  2. 从左到右遍历数组,遍历到 $i$ 时,对于所有 $a_i$ 的 copy,仅保留最靠右的那一个(也就是 $i$),之前的所有 copy 全部删除。
  3. 回答所有 $[L, i]$ 的询问。

那么对于本题,第一步是离线处理询问,按右端点 $R$ 排序。

第二步是遍历数组,遍历过程中维护一个 int pos[] 数组,其中 pos[val] 代表:在 $arr[1…i]$ 中,值为 val 的数 最靠右的 index

当我们遍历到 $i$ 时,令 int val = arr[i],将 pos[val] 处的数字 删掉(更新线段树),然后将 $i$ 处的数字 加入线段树,最后更新一下 pos[val] = i;

•本题中,删掉就是将 线段树的位置 $i$ 的值 减去$1$,加上就是将 线段树的位置 $i$ 的值 加上$1$。

然后回答所有 以 $i$ 为右端点的询问(求 $[L,i]$ 的和即可)。

注:如果询问 在线 怎么办?可以使用主席树!(对于 pr[i] 建主席树)

具体做法参考:CF1422F Boring Queries

代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e6+5;

int n, m, arr[maxn];
struct node {
    int sum;
} tr[maxn<<2];

void push_up(int cur) {
    tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}

void update(int cur, int l, int r, int p, int x) {
    if (l == r) {
        tr[cur].sum += x; return;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) update(cur<<1, l, mid, p, x);
    else update(cur<<1|1, mid+1, r, p, x);
    push_up(cur);
} 

int query(int cur, int l, int r, int L, int R) {
    if (l >= L && r <= R) return tr[cur].sum;
    int mid = (l+r) >> 1;
    int res = 0;
    if (L <= mid) res += query(cur<<1, l, mid, L, R);
    if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
    return res;
}

int ans[maxn];
int pos[maxn];
struct Query {
    int id,l,r;
} q[maxn];

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    cin >> m;
    for (int i = 1; i <= m; i++) {
        int l,r;
        cin >> l >> r;
        q[i] = {i,l,r};
    }
    
    sort(q+1,q+m+1, [](auto a, auto b) {
        return a.r < b.r;  // 根据右端点离线
    });
    int ptr = 1;
    for (int i = 1; i <= n; i++) {
        int val = arr[i];
        if (pos[val]) update(1, 1, n, pos[val], -1);  // 删去 pos[val]
        update(1, 1, n, i, 1);  // 加上 i 

        pos[val] = i;  // 更新 pos[val]
        while (ptr <= m && q[ptr].r == i) {  // 回答所有以 i 作为右端点的询问
            int id = q[ptr].id, L = q[ptr].l, R = q[ptr].r;
            ans[id] = query(1, 1, n, L, R);
            ptr++;
        }
        if (ptr > m) break;
    }
    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例4 CF1000F One Occurrence

题意

给定一个长度为 $n$ 的数组,$m$ 个询问,每次询问一个区间 $[l,r]$,如果这个区间里存在只出现一次的数,输出这个数(如果有多个就输出任意一个),没有就输出 $0$。

其中 $n,m \leq 5 \times 10^5, 1 \leq a_i \leq 5 \times 10^5$。

题解

和例3类似的套路,也是离线处理询问(根据右端点sort),从左往右遍历,仅保留最靠右的复制。

问题在于怎么删除 和 加入数字?因为本题不再是求数量了,所以不能简单的加 $1$ 或者 减 $1$。

会发现,我们仅关心一个区间内是否存在 unique 的数字,对于一个询问 $[L,R]$ 内,我们只要看,是否存在 $i$ 使得 $arr[i]$ 的前一个复制 不在 $[L,R]$ 内。(也就是说,pos[val] 是否小于 $L$)

那么,我们用线段树维护一下 区间最小值 即可,其中 $[L,R]$ 的区间最小值就代表着所有$i \in [L,R]$ 中,pos[arr[i]] 的最小值。如果一个区间 $[L,R]$ 的最小值 $\geq L$,那么答案不存在,否则答案存在。

那么,删除位置 $i$ 就是将它这个位置上的值设为 $inf$。

加入位置 $i$ 的数,就是将它这个位置上的值设为 pos[arr[i]]

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;

struct node {
    int pre = 1e9, idx = 0;  // pre 代表 pos[val] 的值,idx代表这个最小值对应的 index
} tr[maxn<<2];
int pos[maxn], n, m, arr[maxn];
struct Query {
    int id,l,r;
} q[maxn];

void push_up(int cur) {
    int lpre = tr[cur<<1].pre, rpre = tr[cur<<1|1].pre;
    if (lpre < rpre) tr[cur].pre = lpre, tr[cur].idx = tr[cur<<1].idx;
    else tr[cur].pre = rpre, tr[cur].idx = tr[cur<<1|1].idx;
}

void update(int cur, int l, int r, int p, int x) {
    if (l == r) {
        tr[cur].pre = x, tr[cur].idx = l;
        return;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) update(cur<<1, l, mid, p, x);
    else update(cur<<1|1, mid+1, r, p, x);
    push_up(cur);
}

pii query(int cur, int l, int r, int L, int R) {
    if (l >= L && r <= R) return {tr[cur].pre, tr[cur].idx};
    int mid = (l+r) >> 1;
    pii r1 = {1e9, 0}, r2 = {1e9, 0};
    if (L <= mid) r1 = query(cur<<1, l, mid, L, R);
    if (R > mid) r2 = query(cur<<1|1, mid+1, r, L, R);
    if (r1.first > r2.first) return r2;
    else return r1;
}
int ans[maxn];

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    cin >> m;
    for (int i = 1; i <= m; i++) {
        int l,r; cin >> l >> r;
        q[i] = {i,l,r};
    }
    sort(q+1, q+m+1, [](auto a, auto b) {
        return a.r < b.r;
    });

    int ptr = 1;
    for (int i = 1; i <= n; i++) {
        int val = arr[i];
        update(1, 1, n, i, pos[val]);  // 在i处 加入数字 pos[val]
        if (pos[val]) update(1, 1, n, pos[val], 1e9);  // 在pos[val] 处删除数字(设为 inf)
        pos[val] = i;
        while (ptr <= m && q[ptr].r == i) {
            int L = q[ptr].l, R = q[ptr].r;
            pii res = query(1, 1, n, L, R);
            int id = q[ptr].id;

            if (res.first >= L) ans[id] = 0;  // 如果这个区间内,所有 pos[val] 都 >= L
            else ans[id] = arr[res.second];  //  否则,答案存在
            ptr++;
        }
        if (ptr > m) break;
    }
    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例5 CF803G Periodic RMQ Problem

题意

给定一个长度为 $n$ 的数组,将这个数组复制为 $k$ 份并且拼接在一起。

然后回答 $Q$ 个询问,分两种询问:(所有询问都在拼接后的数组上进行)

$1 ~ l ~ r ~ x$:将 $[l,r]$ 中的所有元素改为 $x$。

$2 ~ l ~ r$:询问 $[l,r]$ 中的最小值。

其中,$1 \leq n \leq 10^5, 1 \leq k \leq 10^4, 1 \leq Q \leq 10^5, ~l,r \in [1,n]$

题解

注意到数组的总长度可以达到 $10^9$。但是询问只有 $Q=10^5$,我们于是想到了动态开点线段树。

但是注意到,这个数组是有初始值的,按理说应该把树建好,不能动态开点。

所以我们考虑 不建树,而是在开点的时候,把这个点的初始状态处理好(就是在没有任何修改的情况下,这个点的初始状态),然后再对这个点进行操作。

所以,对于一个点对应的区间 $[l,r]$,怎么处理初始状态?

注意到,由于数组是一个循环节,所以可以分以下三种情况:

  1. $[l,r]$ 的长度 $\geq n$,它的初始最小值就是数组 $[1,n]$ 的最小值。
  2. $[l,r]$ 的长度 $< n$,且属于同一个循环节,那么用 ST表 预处理一下 $[l \text{ mod } n,r\text{ mod } n]$ 的最小值即可。
  3. $[l,r]$ 的长度 $< n$,且不属于同一个循环节,那么它的初始最小值就是 $[l \text{ mod } n, n] \bigcup [1, r\text{ mod } n]$

开点的函数是代码里的 build(),注意到只要我们访问到了 cur,并且需要访问它的任意一个 child 时,需要把左右两个子树都开点,这保证了 pushup() 的正确性。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e7+5;
const int maxm = 1e5+5;

int n,k,Q;
int a[maxm];
int belong(int x) {
    return (x-1) / n + 1;
}

struct Node {
    int lc, rc, lazy, val;
} tr[maxn];
int id = 0;

int st[maxm][18];
int bin[maxm];
void build_st() {
    bin[1] = 0, bin[2] = 1;
    for (int i = 3; i < maxm; i++) bin[i] = bin[i>>1] + 1;
    for (int i = 1; i <= n; i++) {
        st[i][0] = a[i];
    }
    for (int j = 1; j < 18; j++) {
        for (int i = 1; i + (1<<j) - 1 <= n; i++) {
            st[i][j] = min(st[i][j-1], st[i+(1<<(j-1))][j-1]);
        }
    }
}
int ask_st(int l, int r) {
    if (l > r) swap(l,r);
    int j = bin[r-l+1];
    return min(st[l][j], st[r-(1<<j)+1][j]);
}

void build(int& cur, int l, int r) {
    if (cur) return;
    cur = ++id;
    if (r-l+1 >= n) {
        tr[cur].val = ask_st(1, n);
        return;
    }
    int bl = belong(l), br = belong(r);
    l %= n, r %= n;
    if (!l) l = n;
    if (!r) r = n;
    if (bl == br) {
        tr[cur].val = ask_st(l,r);
        return;
    }
    // bl != br
    tr[cur].val = min(ask_st(l, n), ask_st(1, r));
}

void push_down(int cur) {
    if (!tr[cur].lazy) return;
    int lazy = tr[cur].lazy; 
    tr[cur].lazy = 0;
    int lc = tr[cur].lc, rc = tr[cur].rc;
    tr[lc].lazy = tr[lc].val = tr[rc].lazy = tr[rc].val = lazy;
}

void push_up(int cur) {
    tr[cur].val = min(tr[tr[cur].lc].val, tr[tr[cur].rc].val);
}

void update(int cur, int l, int r, int L, int R, int x) {
    if (l >= L && r <= R) {
        tr[cur].lazy = tr[cur].val = x;
        return;
    }
    int mid = (l+r) >> 1;
    build(tr[cur].lc, l, mid);
    build(tr[cur].rc, mid+1, r);
    push_down(cur);
    if (L <= mid) update(tr[cur].lc, l, mid, L, R, x);
    if (R > mid) update(tr[cur].rc, mid+1, r, L, R, x);
    push_up(cur);
}

int query(int cur, int l, int r, int L, int R) {
    if (l >= L && r <= R) return tr[cur].val;
    int mid = (l+r) >> 1;
    build(tr[cur].lc, l, mid);
    build(tr[cur].rc, mid+1, r);
    push_down(cur);
    int lres = 1e9, rres = 1e9;
    if (L <= mid) lres = query(tr[cur].lc, l, mid, L, R);
    if (R > mid) rres = query(tr[cur].rc, mid+1, r, L, R);
    push_up(cur);
    return min(lres, rres);
}

int rt = 0;
int main() {
    fastio;
    cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build_st();
    cin >> Q;
    build(rt, 1, n*k);
    for (int i = 1; i <= Q; i++) {
        int op,l,r,x;
        cin >> op;
        if (op == 1) {
            cin >> l >> r >> x;
            update(1, 1, n*k, l, r, x);
        } else {
            cin >> l >> r;
            int res = query(1, 1, n*k, l, r);
            cout << res << "\n";
        }
    }
}

例6 CF gym103687F. Easy Fix

题意

给定一个 permutation $p_1, p_2 … p_n$。

定义 $A_i$ 为 $i$ 左边比 $p_i$ 小的元素数量,$B_i$ 为 $i$ 右边比 $p_i$ 大的元素数量,$C_i = \min(A_i, B_i)$。

给定 $m$ 个询问,每次询问 $L, R$,回答在交换 $p_L, p_R$ 的位置以后,$\sum\limits_{i=1}^n C_i$ 的值?

其中,$n \leq 10^5, m \leq 2 \times 10^5$。

题解

先讨论一下交换后有哪些元素受到影响?

很明显,$L, R$ 位置会受到影响,$i \in (L,R)$ 且 $p_i \in (p_L, p_R)$ 的会受到影响,其他的均无影响。


我们先讨论 $i \in (L,R)$ 且 $p_i \in (p_L, p_R)$ 这种情况:

Case1: 如果 $a_u < a_v$,那么 $A_i - 1, B_i + 1$。

 Case1.1: 如果 $A_i \leq B_i$,那么 $C_i - 1$。

 Case1.2: 如果 $A_i > B_i + 1$,那么 $C_i + 1$。

Case2: 如果 $a_u > a_v$,那么 $A_i + 1, B_i - 1$。

 Case2.1: 如果 $A_i \geq B_i$,那么 $C_i - 1$。

 Case2.2: 如果 $A_i < B_i - 1$,那么 $C_i + 1$。


然后再讨论一下 $L, R$ 受到的影响:

只要先把 $C_L, C_R$ 的值从询问中减掉,然后重新找一下换了之后的 $C_L, C_R$ 即可。

• 需要注意当 $p_L > p_R$ 时,换了之后 $R$ 位置上的 $p_L$ 计算 $A_i$ 时会将 $p_R$ 忽略掉,所以要加 $1$。


好的,现在该讨论一下怎么获得 $A_i, B_i, C_i$。

首先,获得 $A_i$ 可以直接用树状数组维护 $p_i$(作为index),树状数组上的值为 0 和 1。遍历一下就可以求出初始状态的 $A_i$ 了。

然后 $B_i = p_i - 1 - A_i$,很好理解,这样 $C_i$ 也有了。


再讨论一下怎么处理询问?

我们将 Case1, Case2 分开处理:如果我们开一个线段树/树状数组,index 为 $p_i$,值为 $-1, 0, 1$ 来记录 Case1.1, Case1.2 的情况,那么:

每次询问就是询问一个区间 $(L,R)$ 内,满足 $(p_L, p_R)$ 的值的一些东西的和。

似乎是 主席树

确实,不过本题询问离线,Stupidcdd 教导我:

能用主席树做的,离线情况下就能用线段树/树状数组。

回忆一下主席树的作用:对于每个区间 $[L,R]$,维护一个权值线段树,原理其实是每个点维护一个线段树,$R$ 上的线段树减去 $L-1$ 位置的线段树而已。

那么询问离线的时候,我们只有一个树状数组,可以在遍历的过程中,获得前缀 $[1…i]$ 对应的树状数组,但是我们没有持久化,所以保存不下来。

因为树状数组的历史版本保存不下来,不妨从询问下手?

离线的一个常见套路:

对于询问的端点 $L,R$,每个端点开一个 vector 储存所有的询问,遍历到这个点时处理所有对应的询问,加到答案上。

这个题就是这样,我们想要知道 $(L,R)$ 内满足 $(p_L, p_R)$ 的值的一些东西的和,那么不妨在遍历到 $L$ 的时候询问一下 $(p_L, p_R)$,在遍历到 $R-1$ 的时候询问一下 $(p_L, p_R)$,然后一减,不就是答案了么?

struct Query_Info {
    int pl, pr;  // 询问 [pl, pr] 的部分
    int f;  // 符号
    int idx;  // 询问的 idx
    int type;  // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn];  // qinfo[i] 端点为 i 的需要询问


struct Single_Query {
    int x;  // 询问 < x 有多少个
    int idx;  // 询问的 idx
    int diff;  // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn];  // 单点询问 < x 有多少个
ll ans[maxn];

...
...
...
    // (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
    for (int j = 1; j <= m; j++) {
        auto q = queries[j];
        int u = q.first, v = q.second;
        if (u != v) {
            int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
            int type = (p[u] < p[v] ? 1 : 2);
            qinfo[u+1].push_back({mn, mx, -1, j, type});  // 由于处理时,还没加上,所以全部应该 + 1  (u -> u+1, v-1 -> v)
            qinfo[v].push_back({mn, mx, 1, j, type});
            ans[j] -= (C[u] + C[v]);  // 先去掉 C[u] + C[v] 的影响,后面直接加上
            single[u].push_back({p[v], j, 0});
            single[v].push_back({p[u], j, p[u] > p[v]});
        }
    }

    for (int i = 1; i <= n; i++) {
        for (auto q : qinfo[i]) {
            if (q.type == 1) {
                ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
            } else {
                ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
            }
        }
        for (auto q : single[i]) {
            int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
            ans[q.idx] += min(AA, BB);
        }

        if (A[i] <= B[i]) tr1.update(p[i], -1);
        if (A[i] > B[i] + 1) tr1.update(p[i], 1);
        if (A[i] >= B[i]) tr2.update(p[i], -1);
        if (A[i] < B[i] - 1) tr2.update(p[i], 1);
        tr3.update(p[i], 1);
    }

在这个题中:

  1. $(L,R)$ 我们通过离线询问,答案相减来实现。
  2. $(p_L, p_R)$ 我们通过权值线段树/权值树状数组来实现。
  3. 树状数组上面维护了 Case1.1, 1.2 (或者 2.1, 2.2) 的情况。

总结一下思路:

离线的关键在于从左向右遍历的时候,获得每个前缀对应的树状数组。

在离线 $(L,R)$ 时,将询问的信息扔给每个端点的 vector,在遍历到的时候再处理,处理时直接对 ans[] 进行操作。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+55;
const int maxm = 2e5+5;

int n, p[maxn], m;
pii queries[maxm];

struct BIT {
    int n, tr[maxn];
    inline int lowbit(int x) { return x & -x; }
    void update(int p, int val) {
        while (p <= n) {
            tr[p] += val;
            p += lowbit(p);
        }
    }
    // return sum[1...p]
    int query(int p) {
        int ans = 0;
        while (p > 0) {
            ans += tr[p];
            p -= lowbit(p);
        }
        return ans;
    }
} tr, tr1, tr2, tr3;


struct Query_Info {
    int pl, pr;  // 询问 [pl, pr] 的部分
    int f;  // 符号
    int idx;  // 询问的 idx
    int type;  // (pu < pv) : tr1, (pu > pv): tr2, type = 1 代表 tr1
};
int A[maxn], B[maxn], C[maxn];
int a[maxn], b[maxn], c[maxn];
vector<Query_Info> qinfo[maxn];  // qinfo[i] 端点为 i 的需要询问


struct Single_Query {
    int x;  // 询问 < x 有多少个
    int idx;  // 询问的 idx
    int diff;  // 如果 p[u] > p[v],那么换出去以后 A[v]需要+1
};
vector<Single_Query> single[maxn];  // 单点询问 < x 有多少个
ll ans[maxn];
ll csum = 0;

int main() {
    fastio;
    cin >> n;
    tr.n = tr1.n = tr2.n = tr3.n = n;
    for (int i = 1; i <= n; i++) cin >> p[i];
    cin >> m;
    for (int i = 1; i <= m; i++) {
        int l, r; cin >> l >> r;
        if (l >= r) swap(l, r);
        queries[i] = {l, r};
    }
    for (int i = 1; i <= n; i++) {
        A[i] = tr.query(p[i] - 1);  // 左边比它小
        B[i] = p[i] - 1 - A[i];  // 右边比它小
        C[i] = min(A[i], B[i]);
        csum += C[i];
        tr.update(p[i], 1);
    }
    for (int i = 1; i <= m; i++) ans[i] = csum;

    // (u,v) 的话需要让 v-1 的部分 (前缀 v-1 询问 (a_u, a_v) 的部分)减去 u 的部分(前缀 u 询问 (a_u, a_v) 的部分)
    for (int j = 1; j <= m; j++) {
        auto q = queries[j];
        int u = q.first, v = q.second;
        if (u != v) {
            int mn = min(p[u], p[v]) + 1, mx = max(p[u], p[v]) - 1;
            int type = (p[u] < p[v] ? 1 : 2);
            qinfo[u+1].push_back({mn, mx, -1, j, type});  // 由于处理时,还没加上,所以全部应该 + 1  (u -> u+1, v-1 -> v)
            qinfo[v].push_back({mn, mx, 1, j, type});
            ans[j] -= (C[u] + C[v]);  // 先去掉 C[u] + C[v] 的影响,后面直接加上
            single[u].push_back({p[v], j, 0});
            single[v].push_back({p[u], j, p[u] > p[v]});
        }
    }

    for (int i = 1; i <= n; i++) {
        for (auto q : qinfo[i]) {
            if (q.type == 1) {
                ans[q.idx] += q.f * (tr1.query(q.pr) - tr1.query(q.pl-1));
            } else {
                ans[q.idx] += q.f * (tr2.query(q.pr) - tr2.query(q.pl-1));
            }
        }
        for (auto q : single[i]) {
            int AA = tr3.query(q.x - 1) + q.diff, BB = max(0, q.x - AA - 1);
            ans[q.idx] += min(AA, BB);
        }

        if (A[i] <= B[i]) tr1.update(p[i], -1);
        if (A[i] > B[i] + 1) tr1.update(p[i], 1);
        if (A[i] >= B[i]) tr2.update(p[i], -1);
        if (A[i] < B[i] - 1) tr2.update(p[i], 1);
        tr3.update(p[i], 1);
    }

    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例7 Atcoder ABC237G. Range Sort Query

题意

给定一个长度为 $n$ 的permutation $p_1,p_2…,p_n$,和一个正整数 $x$。

给定 $q$ 个询问,询问有两种:

$1 ~ L ~ R$:将 $p[L…R]$ 从小到大排序。

$2 ~ L ~ R$:将 $p[L…R]$ 从大到小排序。

所有询问结束后,回答 $x$ 所在的index。

其中,$n,q \leq 2 \times 10^5, x \in [1,n]$。

题解

经典套路题。一般看见这种每次询问然后排序的,一般来讲只有在整个数组中 distinct 的数非常少的时候才能做。

这个题注意到我们只关心 $x$,于是我们可以把剩下的数分为比 $x$ 小和比 $x$ 大的。

而除了 $x$ 以外所有数之间的相对顺序我们并不关心。

所以我们可以直接开两棵线段树,线段树里只储存 $0$ 或者 $1$,代表这个位置是否存在比 $x$ 大/小 的数。

而 sort 的时候,我们就相当于把两棵线段树的 $[L,R]$ 中,看一下 $<x$ 和 $>x$ 分别有多少个,然后清空整个线段,然后重新进行区间赋值即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 4e5+5;
const int maxm = 1e5+5;

struct Tree_Node {
    int sum = 0;
    bool flag = 0;
    int val;  // 区间赋的值
};

struct Segment_Tree {
    Tree_Node tr[maxn<<2];
    void push_up(int cur) {
        tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
    }
    void push_down(int cur, int l, int r) {
        if (!tr[cur].flag) return;
        int val = tr[cur].val;
        tr[cur].flag = 0;

        int mid = (l+r) >> 1;
        int lc = cur<<1, rc = cur<<1|1;
        int llen = (mid - l + 1), rlen = (r - mid);
        tr[lc].flag = tr[rc].flag = 1;
        tr[lc].val = tr[rc].val = val;
        tr[lc].sum = llen * val, tr[rc].sum = rlen * val;
    }
    // 区间赋值为 x
    void update(int cur, int l, int r, int L, int R, int x) {
        if (L > R) return;
        if (L <= l && R >= r) {
            tr[cur].flag = 1;
            tr[cur].val = x;
            tr[cur].sum = (r-l+1) * x;
            return;
        }
        push_down(cur, l, r);
        int mid = (l+r) >> 1;
        if (L <= mid) update(cur<<1, l, mid, L, R, x);
        if (R > mid) update(cur<<1|1, mid+1, r, L, R, x);
        push_up(cur);
    }
    // 区间求和
    int query(int cur, int l, int r, int L, int R) {
        if (L <= l && R >= r) {
            return tr[cur].sum;
        }
        push_down(cur, l, r);
        int mid = (l+r) >> 1;
        int res = 0;
        if (L <= mid) res += query(cur<<1, l, mid, L, R);
        if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
        return res;
    }
} tr1, tr2;  // 比它小,比它大

int n, q, x, p;
int a[maxn];
int main() {
    cin >> n >> q >> x;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        if (a[i] == x) p = i;
    }
    for (int i = 1; i <= n; i++) {
        if (a[i] < x) tr1.update(1, 1, n, i, i, 1);
        if (a[i] > x) tr2.update(1, 1, n, i, i, 1);
    }
    while (q--) {
        int c, L, R; cin >> c >> L >> R;
        int sum1 = tr1.query(1, 1, n, L, R);
        int sum2 = tr2.query(1, 1, n, L, R);

        tr1.update(1, 1, n, L, R, 0);
        tr2.update(1, 1, n, L, R, 0);

        if (c == 1) {
            tr1.update(1, 1, n, L, L+sum1-1, 1);
            tr2.update(1, 1, n, R-sum2+1, R, 1);
            if (p >= L && p <= R) p = L + sum1;
        } else {
            tr2.update(1, 1, n, L, L+sum2-1, 1);
            tr1.update(1, 1, n, R-sum1+1, R, 1);
            if (p >= L && p <= R) p = L + sum2;
        }
    }
    cout << p << endl;
}

例8 Eerie Subarrays

题意

给定一个 $1$ 到 $n$ 的permutation,求有多少个 subarray 使得其最左边的元素是这个subarray的中位数?

其中,$n \leq 2 \times 10^5$。

题解

假设我们正在考虑的数是 $a_i$,那么只用考虑起点从 $i$ 开始的subarray。

将所有 $< a_i$ 的数设为 $-1$,所有 $> a_i$ 的数设为 $1$,那么问题转化为:

求 $R > i$ 的数量,使得 $sum[i…R] = 0$ (等价于使得 $sum[i-1] = sum[R]$)。

线段树没法统计这个,但可以用分块来统计!


同时注意到给定的是一个 $1$ 到 $n$ 的permutation,所以我们从 $1$ 开始,初始状态下所有值均为 $1$,而 $1$ 所在的位置为 $0$,之后考虑下一个数的时候就是将 $1$ 所在的位置设为 $-1$,$2$ 所在的位置设为 $0$,以此类推。

这样每次更新只需要更改两个位置的数,而更改一个位置的数会影响后面的所有的 $sum$ 值。对于块内的,暴力更新,对于后面的所有块,给块加上一个标记即可。所以每次更新复杂度是 $O(\sqrt n)$。


最后如何统计答案?对于块内,暴力统计,对于一个整块,我们维护块内每个元素出现的次数即可。

如何维护?显然不能对每个块都开 $1$ 到 $n$ 的数组。但注意到每个块里面最多 $B = \sqrt n$ 个元素,并且每个元素被暴力减的次数 $\leq 2B$。所以我们可以先将每个块的所有元素初始化到 $[2B, 3B]$ 的范围内(通过更新标记来初始化),然后每个块里面只要维护一个长度为 $3B$ 的vector即可,总空间复杂度就是 $O(n)$ 了。

代码
#include <bits/stdc++.h>
using namespace std;

const int maxn = 2e5+5;


int n, p[maxn], pos[maxn];
vector<int> b[maxn];
int f[maxn], B;

// [1,B] -> vector<int>, [B+1, 2B] -> vector<int>
int L[maxn], R[maxn];
int bel[maxn];
int bmax = 0;

vector<int> cnt[maxn];  // cnt[bi][i]: 代表 i 在 block(bi) 中的出现次数

// 将位置 x 的值减 1
void upd(int x) {
    int bi = bel[x];
    int l = L[bi];

    int ori = b[bi][x-l];
    b[bi][x - l] -= 1;
    
    cnt[bi][ori]--;
    cnt[bi][ori-1]++;
}

// 将位置 x 的值减 1, 说明所有 >= x 的值都要减 1
void updall(int x) {
    int bi = bel[x];
    int r = R[bi];
    for (int j = x; j <= r; j++) {
        upd(j);
    }
    for (int b = bi+1; b <= bmax; b++) {
        f[b] -= 1;
    }
}

// 回答 sum[x] 的值
int check(int x) {
    int bi = bel[x], l = L[bi];
    return f[bi] + b[bi][x-l];
}

// 询问有多少个 y >= x 使得 sum[y] = sum[x]
ll query(int x) {
    int sumx = check(x);
    int bi = bel[x], r = R[bi];
    ll res = 0;
    for (int j = x; j <= r; j++) {
        if (check(j) == sumx) res++;
    }
    for (int b = bi+1; b <= bmax; b++) {
        int flag = f[b];
        // 有多少个 cnt[y] 使得 y + flag == sumx -> y == sumx - flag
        if (sumx - flag >= 0 && sumx - flag < cnt[b].size()) res += cnt[b][sumx-flag];
    }
    return res;
}

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> p[i], pos[p[i]] = i;
    fill(L, L+maxn, 1e9);

    B = sqrt(n);
    for (int i = 1; i <= n; i++) {
        bel[i] = (i-1) / B + 1;
        int bi = bel[i];
        L[bi] = min(L[bi], i);
        R[bi] = max(R[bi], i);
        bmax = max(bmax, bi);
    }

    for (int i = 1; i <= bmax; i++) {
        int j = 3*B + 10;
        while (j--) cnt[i].push_back(0);
    }

    for (int i = 1; i <= n; i++) {
        int bi = bel[i];
        // 1: [1, B] -> flag = -2B, 2: [B+1, 2B] -> flag = 0, 3: flag = B
        int flag = (bi - 3) * B;
        f[bi] = flag;
        b[bi].push_back(i - flag);
        cnt[bi][i-flag]++;
    }

    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        updall(pos[i]);
        if (i > 1) updall(pos[i-1]);
        ans += query(pos[i]);
    }
    cout << ans << endl;
}