主要记录一些遇到的线段树/分块例题。

例1 CF438D

题意

给定 $N$ 个正整数和 $M$ 个询问,询问有 3 种:

$1 ~ l ~ r$:输出 $\sum\limits_{i=l}^r a_i$

$2 ~ l ~ r ~ x$:将 $a_l$ 到 $a_r$ 的所有数取 $\text{mod } x$

$3 ~ k ~ x$:将 $a_k = x$

其中,$1 \leq N,M \leq 10^5, 1 \leq a_i,x \leq 10^9$

题解

本题分块和线段树都可以做,我们这里用 线段树 来做。

主要是需要考虑 区间取模 怎么办?

回忆一下分块例题中的 区间开方,我们维护了一个额外的tag表示这个区间是否为 全0/全1,如果不是 全0/全1 就暴力开方。

取模操作同理,我们发现,如果 $a_i > x$,那么 $a_i \text{ mod } x \leq \frac{a_i}{2}$,所以对于每个 $a_i$,最多只会被 $\text{mod}$ $\log (a_i)$ 次!

所以,我们维护一个 区间最大值,取模时,检查一下 区间最大值是否大于 $x$

  1. 如果大于 $x$,就继续往下递归。
  2. 如果小于 $x$,就直接返回。

base case 就是区间长度为 $1$ 时,直接对这个元素开方即可。

例2 CF558E

题意

给定一个长度为 $n$ 的string,仅包含小写字母。给 $q$ 个询问:

$l,r,k$:将string的 $[l,r]$ 进行sort,如果 $k=1$ 就升序,$k=0$ 降序。

输出所有询问结束后的string。

其中,$1\leq n \leq 10^5, 1 \leq q\leq 50000$

题解

线段树来处理。

首先,string只包含小写字母。所以每个 node 可以维护一个 cnt[26] 代表这个node里的每个字母出现的次数。

其次,对于排序,我们在每个 node 中维护一个标记 $k$ 来代表该区间是否排序好了。若 $k=0$ 代表降序,$k=1$ 代表升序,$k = -1$ 代表乱序。

最后,维护一个 $lazy$ 标记,我们会注意到对于一个node而言,若 $lazy = 1$,那么这个 node 必然是排序好了的!(要么 $k=0$ ,要么 $k=1$)。

有了以上信息,我们就可以进行 sort 操作了!


sort $[L,R]$ 的时候,步骤如下:

  1. 提取出 $[L,R]$ 内每个字母出现的次数。
  2. 求出 $[L,R]$ 与 $[l,mid]$(当前node 左child的范围)的区间交集 $[l_1,r_1]$
  3. 求出 $[L,R]$ 与 $[mid+1, r]$(当前node 右child的范围)的区间交集 $[l_2,r_2]$
  4. 用指针遍历 $a-z$(或者 $z-a$),根据升序/降序 将 字母出现的次数分别填充左child和右childcnt[] 中。(注意,这里的填充是指:先填充进一个 int* buf = new int[26]; 的动态数组,然后buf[] 作为参数,再往下传递,直到区间完全覆盖,再将 buf[] 的内容复制进 cnt[] 里)。
代码
using namespace std;
#include <bits/stdc++.h>
const int maxn = 1e5+5;
const int maxm = 2e5+10;
 
int n,q;
char arr[maxn];
 
struct Node {
    int l,r,k,cnt[26];
    bool lazy = 0;
} tr[4*maxn];
int tmp[26];
 
inline int len(int cur) { return tr[cur].r - tr[cur].l + 1; }
 
void push_up(int cur) {
    int lc = cur<<1, rc = lc+1;
    for (int i = 0; i < 26; i++) {
        tr[cur].cnt[i] = tr[lc].cnt[i] + tr[rc].cnt[i];
    }
    if (tr[lc].k != -1 && tr[lc].k == tr[rc].k) tr[cur].k = tr[lc].k;  // k = 1: increasing, k = 0: decreasing
    else tr[cur].k = -1;
}
 
void put(int cur, int k) {
    int lc = cur<<1, rc = lc+1;
    memset(tr[lc].cnt, 0, sizeof(tr[lc].cnt));
    memset(tr[rc].cnt, 0, sizeof(tr[rc].cnt));
    memcpy(tmp, tr[cur].cnt, sizeof(tmp));
 
    int lsz = len(lc), rsz = len(rc);
    if (k) {
        int p = 0;
        while (p < 26 && lsz) {
            int delta = min(lsz, tmp[p]);
            tr[lc].cnt[p] += delta;
            lsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p++;
        }
        while (p < 26 && rsz) {
            int delta = min(rsz, tmp[p]);
            tr[rc].cnt[p] += delta;
            rsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p++;
        }
    } else {
        int p = 25;
        while (p >= 0 && lsz) {
            int delta = min(lsz, tmp[p]);
            tr[lc].cnt[p] += delta;
            lsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p--;
        }
        while (p >= 0 && rsz) {
            int delta = min(rsz, tmp[p]);
            tr[rc].cnt[p] += delta;
            rsz -= delta;
            tmp[p] -= delta;
            if (!tmp[p]) p--;
        }
    }
    assert(lsz == 0);
    assert(rsz == 0);
}
 
void push_down(int cur) {
    if (!tr[cur].lazy) return;
    int lc = cur<<1, rc = lc+1;
    tr[cur].lazy = 0;
    tr[lc].lazy = tr[rc].lazy = 1;
    assert(tr[cur].k != -1);
 
    int k = tr[cur].k;
    tr[lc].k = k;
    tr[rc].k = k;
    put(cur,k);
}
 
void build(int cur, int L, int R) {
    tr[cur].l = L, tr[cur].r = R;
    if (L == R) {
        memset(tr[cur].cnt, 0, sizeof(tr[cur].cnt));
        tr[cur].cnt[arr[L]-'a'] = 1;
        return;
    }
    int mid = (L+R) >> 1;
    build(cur<<1, L, mid);
    build(cur<<1|1, mid+1, R);
    push_up(cur);
}
 
int ress[26];  // 每次query的结果会存到这里
 
void clear(int* buf) {
    for (int i = 0; i < 26; i++) buf[i] = 0;
}
 
int inter(int l1, int r1, int l2, int r2) {  // 求区间交集的长度
    int l = max(l1,l2), r = min(r1,r2);
    return max(0,r-l+1);
}
 
void update(int cur, int L, int R, int k, int* res) {  // 注意参数里有个动态数组 res
    int lc = cur<<1, rc = lc+1;
    int l = tr[cur].l, r = tr[cur].r;
    if (l >= L && r <= R) {
        tr[cur].k = k;
        tr[cur].lazy = 1;
        for (int i = 0; i < 26; i++) tr[cur].cnt[i] = res[i];  // 区间完全覆盖,复制到 cnt 中。
        clear(res);  // 记得清空,之后可能还要用
        return;
    }
    int mid = (l+r) >> 1;
    int lsz = inter(l,mid,L,R), rsz = inter(mid+1,r,L,R);
 
    int* buf = new int[26];  // 这里采用了动态数组
    for (int i = 0; i < 26; i++) buf[i] = 0;  //注意new出来的需要先清空一下,另外不能使用 memset(因为是指针)
 
    if (k) {
        int p = 0;
        while (p < 26 && lsz) {
            int delta = min(lsz, res[p]);
            buf[p] += delta;
            lsz -= delta;
            res[p] -= delta;
            if (!res[p]) p++;
        }
        if (L <= mid) {
            update(lc, L, R, k, buf);  // 传递 buf
        }
 
        while (p < 26 && rsz) {
            int delta = min(rsz, res[p]);
            buf[p] += delta;
            rsz -= delta;
            res[p] -= delta;
            if (!res[p]) p++;
        }
 
        if (R > mid) {
            update(rc, L, R, k, buf);  // 传递 buf
        }
 
    } else {
        int p = 25;
        while (p >= 0 && lsz) {
            int delta = min(lsz, res[p]);
            buf[p] += delta;
            lsz -= delta;
            res[p] -= delta;
            if (!res[p]) p--;
        }
        if (L <= mid) {
            update(lc, L, R, k, buf);  // 传递 buf
        }
        while (p >= 0 && rsz) {
            int delta = min(rsz, res[p]);
            buf[p] += delta;
            rsz -= delta;
            res[p] -= delta;
            if (!res[p]) p--;
        }
        if (R > mid) {
            update(rc, L, R, k, buf);  // 传递 buf
        }
    }
    delete[] buf;
 
    assert(lsz == 0);
    assert(rsz == 0);
    push_up(cur);
}
 
void query(int cur, int L, int R) {  // 求区间内每个字母出现的个数
    int l = tr[cur].l, r = tr[cur].r;
    if (l >= L && r <= R) {
        for (int i = 0; i < 26; i++) ress[i] += tr[cur].cnt[i];
        return;
    }
 
    int lc = cur<<1, rc = lc+1;
    push_down(cur);
    int mid = (l+r) >> 1;
    if (L <= mid) query(lc, L, R);
    if (R > mid) query(rc, L, R);
    push_up(cur);
}
 
void printans() {
    for (int i = 1; i <= n; i++) {
        memset(ress, 0, sizeof(ress));
        query(1,i,i);
        for (int j = 0; j < 26; j++) {
            if (ress[j]) {
                printf("%c",(char)('a'+j));
                ress[j]--;
                break;
            }
        }
    }
    printf("\n");
}
 
int main() {
 
    scanf("%d%d",&n,&q);
    scanf("%s", arr+1);
    build(1, 1, n);
    while (q--) {
        int l,r,k; scanf("%d%d%d",&l,&r,&k);
        memset(ress, 0, sizeof(ress));
        query(1,l,r);
        update(1,l,r,k,ress);
    }
    printans();
}

例3 洛谷P1972 HH的项链

题意

给定一个长度为 $N$ 的数组 $a_1,a_2,…,a_n$,以及 $m$ 个询问,每次询问 $[L,R]$ 之间有多少个不同的数。

其中,$1 \leq n,m,a_i \leq 10^6$

题解

我们可以发现,如果我们固定了询问的右端点 $R$,那么无论 $L$ 为多少,在 $R$ 的左侧的所有重复数字中,仅保留最靠右的一个 copy 即可。

例如 $arr = 1,3,2,1,7,1$,那么我们在遍历到 $i = 4$ 时,我们仅需要保留最后一个 $1$ (也就是 index 为 $4$ 的数字)。

由上,在处理 区间内不同的数 时,一个常见的套路是:

  1. 离线处理所有询问,按右端点 $R$ 排序。
  2. 从左到右遍历数组,遍历到 $i$ 时,对于所有 $a_i$ 的 copy,仅保留最靠右的那一个(也就是 $i$),之前的所有 copy 全部删除。
  3. 回答所有 $[L, i]$ 的询问。

那么对于本题,第一步是离线处理询问,按右端点 $R$ 排序。

第二步是遍历数组,遍历过程中维护一个 int pos[] 数组,其中 pos[val] 代表:在 $arr[1…i]$ 中,值为 val 的数 最靠右的 index

当我们遍历到 $i$ 时,令 int val = arr[i],将 pos[val] 处的数字 删掉(更新线段树),然后将 $i$ 处的数字 加入线段树,最后更新一下 pos[val] = i;

•本题中,删掉就是将 线段树的位置 $i$ 的值 减去$1$,加上就是将 线段树的位置 $i$ 的值 加上$1$。

然后回答所有 以 $i$ 为右端点的询问(求 $[L,i]$ 的和即可)。

注:如果询问 在线 怎么办?可以使用主席树!(对于 pr[i] 建主席树)

具体做法参考:CF1422F Boring Queries

代码
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 1e6+5;

int n, m, arr[maxn];
struct node {
    int sum;
} tr[maxn<<2];

void push_up(int cur) {
    tr[cur].sum = tr[cur<<1].sum + tr[cur<<1|1].sum;
}

void update(int cur, int l, int r, int p, int x) {
    if (l == r) {
        tr[cur].sum += x; return;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) update(cur<<1, l, mid, p, x);
    else update(cur<<1|1, mid+1, r, p, x);
    push_up(cur);
} 

int query(int cur, int l, int r, int L, int R) {
    if (l >= L && r <= R) return tr[cur].sum;
    int mid = (l+r) >> 1;
    int res = 0;
    if (L <= mid) res += query(cur<<1, l, mid, L, R);
    if (R > mid) res += query(cur<<1|1, mid+1, r, L, R);
    return res;
}

int ans[maxn];
int pos[maxn];
struct Query {
    int id,l,r;
} q[maxn];

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    cin >> m;
    for (int i = 1; i <= m; i++) {
        int l,r;
        cin >> l >> r;
        q[i] = {i,l,r};
    }
    
    sort(q+1,q+m+1, [](auto a, auto b) {
        return a.r < b.r;  // 根据右端点离线
    });
    int ptr = 1;
    for (int i = 1; i <= n; i++) {
        int val = arr[i];
        if (pos[val]) update(1, 1, n, pos[val], -1);  // 删去 pos[val]
        update(1, 1, n, i, 1);  // 加上 i 

        pos[val] = i;  // 更新 pos[val]
        while (ptr <= m && q[ptr].r == i) {  // 回答所有以 i 作为右端点的询问
            int id = q[ptr].id, L = q[ptr].l, R = q[ptr].r;
            ans[id] = query(1, 1, n, L, R);
            ptr++;
        }
        if (ptr > m) break;
    }
    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例4 CF1000F One Occurrence

题意

给定一个长度为 $n$ 的数组,$m$ 个询问,每次询问一个区间 $[l,r]$,如果这个区间里存在只出现一次的数,输出这个数(如果有多个就输出任意一个),没有就输出 $0$。

其中 $n,m \leq 5 \times 10^5$。

题解

和例3类似的套路,也是离线处理询问(根据右端点sort),从左往右遍历,仅保留最靠右的复制。

问题在于怎么删除 和 加入数字?因为本题不再是求数量了,所以不能简单的加 $1$ 或者 减 $1$。

会发现,我们仅关心一个区间内是否存在 unique 的数字,对于一个询问 $[L,R]$ 内,我们只要看,是否存在 $i$ 使得 $arr[i]$ 的前一个复制 不在 $[L,R]$ 内。(也就是说,pos[val] 是否小于 $L$)

那么,我们用线段树维护一下 区间最小值 即可,其中 $[L,R]$ 的区间最小值就代表着所有$i \in [L,R]$ 中,pos[arr[i]] 的最小值。如果一个区间 $[L,R]$ 的最小值 $\geq L$,那么答案不存在,否则答案存在。

那么,删除位置 $i$ 就是将它这个位置上的值设为 $inf$。

加入位置 $i$ 的数,就是将它这个位置上的值设为 pos[arr[i]]

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;

struct node {
    int pre = 1e9, idx = 0;  // pre 代表 pos[val] 的值,idx代表这个最小值对应的 index
} tr[maxn<<2];
int pos[maxn], n, m, arr[maxn];
struct Query {
    int id,l,r;
} q[maxn];

void push_up(int cur) {
    int lpre = tr[cur<<1].pre, rpre = tr[cur<<1|1].pre;
    if (lpre < rpre) tr[cur].pre = lpre, tr[cur].idx = tr[cur<<1].idx;
    else tr[cur].pre = rpre, tr[cur].idx = tr[cur<<1|1].idx;
}

void update(int cur, int l, int r, int p, int x) {
    if (l == r) {
        tr[cur].pre = x, tr[cur].idx = l;
        return;
    }
    int mid = (l+r) >> 1;
    if (p <= mid) update(cur<<1, l, mid, p, x);
    else update(cur<<1|1, mid+1, r, p, x);
    push_up(cur);
}

pii query(int cur, int l, int r, int L, int R) {
    if (l >= L && r <= R) return {tr[cur].pre, tr[cur].idx};
    int mid = (l+r) >> 1;
    pii r1 = {1e9, 0}, r2 = {1e9, 0};
    if (L <= mid) r1 = query(cur<<1, l, mid, L, R);
    if (R > mid) r2 = query(cur<<1|1, mid+1, r, L, R);
    if (r1.first > r2.first) return r2;
    else return r1;
}
int ans[maxn];

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> arr[i];
    cin >> m;
    for (int i = 1; i <= m; i++) {
        int l,r; cin >> l >> r;
        q[i] = {i,l,r};
    }
    sort(q+1, q+m+1, [](auto a, auto b) {
        return a.r < b.r;
    });

    int ptr = 1;
    for (int i = 1; i <= n; i++) {
        int val = arr[i];
        update(1, 1, n, i, pos[val]);  // 在i处 加入数字 pos[val]
        if (pos[val]) update(1, 1, n, pos[val], 1e9);  // 在pos[val] 处删除数字(设为 inf)
        pos[val] = i;
        while (ptr <= m && q[ptr].r == i) {
            int L = q[ptr].l, R = q[ptr].r;
            pii res = query(1, 1, n, L, R);
            int id = q[ptr].id;

            if (res.first >= L) ans[id] = 0;  // 如果这个区间内,所有 pos[val] 都 >= L
            else ans[id] = arr[res.second];  //  否则,答案存在
            ptr++;
        }
        if (ptr > m) break;
    }
    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}