吉司机线段树(Segment Tree Beats)
Contents
介绍
吉司机线段树可以做到:
- 区间最大/最小操作(对一个区间内的所有数取 max 或者 min)
- 维护区间历史最值
我们直接看一道例题:
题意
给定一个数组 $a_1, a_2, … a_n$,同时给定一个数组 $b_1, b_2…, b_n$(初始状态下 $a,b$ 相同)。
进行 $m$ 次操作,操作有 $5$ 种:
$1 ~ l ~ r ~ x$:$\forall i \in [l,r]$,将 $a_i$ 加上 $x$。
$2 ~ l ~ r ~ x$:$\forall i \in [l,r]$,将 $a_i$ 变成 $\min(a_i, x)$。
$3 ~ l ~ r$:求 $[l,r]$ 之间 $a_i$ 的区间和。
$4 ~ l ~ r$:求 $[l,r]$ 之间 $a_i$ 的区间最大值。
$5 ~ l ~ r$:求 $[l,r]$ 之间 $b_i$ 的区间最大值。
在每一次操作后,我们进行一次更新,使得 $\forall i \in [1,n], b_i \leftarrow \max(b_i, a_i)$。
吉司机线段树可以在 $O((n+m) \log^2n)$ 的时间内解决这个问题。
具体的,我们先看一下线段树的节点需要维护什么。
区间历史最值
先考虑第 $5$ 个操作。首先 $b_i$ 其实是一个历史最大值数组(即 $a_i$ 在任意时刻,所存放过的最大的值)。我们要求的是 $b_i$ 的区间最大值。
在区间加的时候,我们会维护一个最大值的懒标记 add1
。不过这个懒标记可能会被多次更新,所以我们只要维护在任意时刻,这个懒标记的最大值 add3
(也就是懒标记本身的 历史最大值)即可。
• 简而言之 add3
就是 add1
在任意时刻所达到的最大值。
区间取最小操作
区间怎么取最小?考虑一下,如果我们要对一个区间取 $\min$,要取的值是 $x$。
那么假设这个区间只有 一种值 是 $>x$ 的,那么就可以取最小了。
所以我们需要维护区间的最大值 maxa
,和次大值 se
。一个区间可以取 min
当且仅当 $se < x < maxa$。
如果不满足这个条件,则要么直接退出($x \geq maxa$),或者继续递归($x < maxa$)。
由于我们需要维护次大值,而区间加也有可能会更新次大值,所以我们还需要 非最大值的懒标记 add2
,同上理由还需要一个 非最大值懒标记 的历史最大值 add4
。
• 简而言之 add4
就是 add2
在任意时刻所达到的最大值。
模版
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
struct SegTreeBeats {
const int INF = 1e9;
int a[maxn]; // 原数组
struct Node {
ll sum;
int maxa, maxb, cnt, se;
// maxa: 区间最大值, maxb: 区间历史最大值, cnt: 区间最大值的数量, se: 区间严格次大值
int add1, add2, add3, add4;
// add1: 区间最大值懒标记, add2: 区间非最大值懒标记, add3: 区间历史最大值懒标记, add4: 区间历史非最大值懒标记
} tr[maxn<<2];
void push_up(int cur) {
int lc = cur<<1, rc = cur<<1|1;
tr[cur].sum = tr[lc].sum + tr[rc].sum;
tr[cur].maxa = max(tr[lc].maxa, tr[rc].maxa);
tr[cur].maxb = max(tr[lc].maxb, tr[rc].maxb);
if (tr[lc].maxa == tr[rc].maxa) {
tr[cur].se = max(tr[lc].se, tr[rc].se);
tr[cur].cnt = tr[lc].cnt + tr[rc].cnt;
} else if (tr[lc].maxa > tr[rc].maxa) {
tr[cur].se = max(tr[lc].se, tr[rc].maxa);
tr[cur].cnt = tr[lc].cnt;
} else { // rc.maxa > lc.maxa
tr[cur].se = max(tr[lc].maxa, tr[rc].se);
tr[cur].cnt = tr[rc].cnt;
}
}
void helper(int cur, int l, int r, ll k1, ll k2, ll k3, ll k4) {
tr[cur].sum += k1 * tr[cur].cnt + k2 * (r-l+1-tr[cur].cnt);
tr[cur].maxb = max((ll)tr[cur].maxb, tr[cur].maxa + k3);
tr[cur].maxa += k1;
if (tr[cur].se != -INF) tr[cur].se += k2;
tr[cur].add3 = max((ll)tr[cur].add3, tr[cur].add1 + k3);
tr[cur].add4 = max((ll)tr[cur].add4, tr[cur].add2 + k4);
tr[cur].add1 += k1, tr[cur].add2 += k2;
}
void push_down(int cur, int l, int r) {
int lc = cur<<1, rc = cur<<1|1;
int mx = max(tr[lc].maxa, tr[rc].maxa);
int mid = (l+r) >> 1;
if (tr[lc].maxa == mx) helper(lc, l, mid, tr[cur].add1, tr[cur].add2, tr[cur].add3, tr[cur].add4); // 注意这里是 1,2,3,4
else helper(lc, l, mid, tr[cur].add2, tr[cur].add2, tr[cur].add4, tr[cur].add4); // 注意这里是 2,2,4,4
if (tr[rc].maxa == mx) helper(rc, mid+1, r, tr[cur].add1, tr[cur].add2, tr[cur].add3, tr[cur].add4); // 注意这里是 1,2,3,4
else helper(rc, mid+1, r, tr[cur].add2, tr[cur].add2, tr[cur].add4, tr[cur].add4); // 注意这里是 2,2,4,4
tr[cur].add1 = tr[cur].add2 = tr[cur].add3 = tr[cur].add4 = 0;
}
void build(int cur, int l, int r) {
if (l == r) {
tr[cur].sum = tr[cur].maxa = tr[cur].maxb = a[l];
tr[cur].cnt = 1;
tr[cur].se = -INF;
return;
}
int mid = (l+r) >> 1;
build(cur<<1, l, mid);
build(cur<<1|1, mid+1, r);
push_up(cur);
}
ll query_sum(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].sum;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
ll res = 0;
if (L <= mid) res += query_sum(cur<<1, l, mid, L, R);
if (R > mid) res += query_sum(cur<<1|1, mid+1, r, L, R);
return res;
}
int query_maxa(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].maxa;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = -INF;
if (L <= mid) res = max(res, query_maxa(cur<<1, l, mid, L, R));
if (R > mid) res = max(res, query_maxa(cur<<1|1, mid+1, r, L, R));
return res;
}
int query_maxb(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].maxb;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = -INF;
if (L <= mid) res = max(res, query_maxb(cur<<1, l, mid, L, R));
if (R > mid) res = max(res, query_maxb(cur<<1|1, mid+1, r, L, R));
return res;
}
void update_add(int cur, int l, int r, int L, int R, ll x) {
if (L <= l && R >= r) {
tr[cur].sum += x * (r-l+1);
tr[cur].maxa += x;
tr[cur].maxb = max(tr[cur].maxb, tr[cur].maxa);
if (tr[cur].se != -INF) tr[cur].se += x;
tr[cur].add1 += x, tr[cur].add2 += x;
tr[cur].add3 = max(tr[cur].add3, tr[cur].add1);
tr[cur].add4 = max(tr[cur].add4, tr[cur].add2);
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update_add(cur<<1, l, mid, L, R, x);
if (R > mid) update_add(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
void update_min(int cur, int l, int r, int L, int R, int x) {
if (x >= tr[cur].maxa) return;
if (L <= l && R >= r && x > tr[cur].se) { // 保证 se > x
ll k = tr[cur].maxa - x; // 最大值减少的幅度
tr[cur].sum -= tr[cur].cnt * k;
tr[cur].maxa = x;
tr[cur].add1 -= k;
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update_min(cur<<1, l, mid, L, R, x);
if (R > mid) update_min(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
} seg;
int n, m;
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> seg.a[i];
seg.build(1, 1, n);
while (m--) {
int op, l, r; cin >> op >> l >> r;
if (op == 1) {
ll x; cin >> x;
seg.update_add(1, 1, n, l, r, x);
} else if (op == 2) {
ll x; cin >> x;
seg.update_min(1, 1, n, l, r, x);
} else if (op == 3) {
cout << seg.query_sum(1, 1, n, l, r) << "\n";
} else if (op == 4) {
cout << seg.query_maxa(1, 1, n, l, r) << "\n";
} else {
cout << seg.query_maxb(1, 1, n, l, r) << "\n";
}
}
}
例题
例1 洛谷P4314 CPU 监控
题意
给定一个数组 $a_1,a_2…,a_n$,有 $m$ 个询问。
进行 $m$ 次操作,操作有 $4$ 种:
$Q ~ l ~ r$:求 $[l,r]$ 之间 $a_i$ 的区间最大值。
$A ~ l ~ r$:求 $[l,r]$ 之间 $a_i$ 的区间历史最大值。
$P ~ l ~ r ~ x$:将 $[l,r]$ 之间的 $a_i$ 增加 $x$。
$C ~ l ~ r ~ x$:将 $[l,r]$ 之间的 $a_i$ 变为 $x$。
题解
题解链接:https://www.luogu.com.cn/blog/He-Ren/solution-p4314
带有区间赋值,维护历史最大值。
我们对于一个区间,在 push_down()
之后,所有之前的操作都可以看作不存在。
那么我们按照每次 push_down()
作为分隔符来考虑这些操作序列,以下的讨论都仅限于 一次 push_down()
以内的。
一个区间的操作只有区间加和区间赋值,并且相邻的相同类型操作可以合并为一个。
所以最终操作序列肯定可以被简化为 区间加 + 区间赋值 + 区间加 + 区间赋值 + 区间加 …
我们可以发现,区间赋值后,所有的区间加操作都可以被重新看作为赋值,所以我们维护一个 set_tag
表示在此次 push_down()
内是否有赋值操作即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
struct SegTreeBeats {
const int INF = 2e9;
int a[maxn]; // 原数组
struct Node {
ll sum;
int maxa, maxb;
// maxa: 区间最大值, maxb: 区间历史最大值, cnt: 区间最大值的数量, se: 区间严格次大值
int add1, add2;
int set1, set2;
bool set_tag;
// add1: 区间最大值懒标记, add2: 区间非最大值懒标记, add2: 区间历史最大值懒标记, add4: 区间历史非最大值懒标记
} tr[maxn<<2];
void push_up(int cur) {
int lc = cur<<1, rc = cur<<1|1;
tr[cur].maxa = max(tr[lc].maxa, tr[rc].maxa);
tr[cur].maxb = max(tr[lc].maxb, tr[rc].maxb);
}
void add_helper(int cur, ll x, ll max_add) {
if (tr[cur].set_tag) {
tr[cur].set2 = max((ll)tr[cur].set2, tr[cur].set1 + max_add);
tr[cur].maxb = max((ll)tr[cur].maxb, tr[cur].maxa + max_add);
tr[cur].maxa += x;
tr[cur].set1 += x;
} else {
tr[cur].add2 = max((ll)tr[cur].add2, tr[cur].add1 + max_add);
tr[cur].maxb = max((ll)tr[cur].maxb, tr[cur].maxa + max_add);
tr[cur].maxa += x;
tr[cur].add1 += x;
}
}
void set_helper(int cur, ll x, ll max_set) {
tr[cur].set2 = max((ll)(tr[cur].set2), max_set);
tr[cur].maxb = max((ll)tr[cur].maxb, max_set);
tr[cur].set_tag = 1;
tr[cur].maxa = tr[cur].set1 = x;
}
void push_down(int cur, int l, int r) {
int lc = cur<<1, rc = cur<<1|1;
add_helper(lc, tr[cur].add1, tr[cur].add2);
add_helper(rc, tr[cur].add1, tr[cur].add2);
tr[cur].add1 = tr[cur].add2 = 0;
if (tr[cur].set_tag) {
set_helper(lc, tr[cur].set1, tr[cur].set2);
set_helper(rc, tr[cur].set1, tr[cur].set2);
tr[cur].set_tag = 0;
tr[cur].set1 = tr[cur].set2 = 0;
return;
}
}
void build(int cur, int l, int r) {
if (l == r) {
tr[cur].sum = tr[cur].maxa = tr[cur].maxb = a[l];
return;
}
int mid = (l+r) >> 1;
build(cur<<1, l, mid);
build(cur<<1|1, mid+1, r);
push_up(cur);
}
ll query_sum(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].sum;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
ll res = 0;
if (L <= mid) res += query_sum(cur<<1, l, mid, L, R);
if (R > mid) res += query_sum(cur<<1|1, mid+1, r, L, R);
return res;
}
int query_maxa(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].maxa;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = -INF;
if (L <= mid) res = max(res, query_maxa(cur<<1, l, mid, L, R));
if (R > mid) res = max(res, query_maxa(cur<<1|1, mid+1, r, L, R));
return res;
}
int query_maxb(int cur, int l, int r, int L, int R) {
if (L <= l && R >= r) {
return tr[cur].maxb;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
int res = -INF;
if (L <= mid) res = max(res, query_maxb(cur<<1, l, mid, L, R));
if (R > mid) res = max(res, query_maxb(cur<<1|1, mid+1, r, L, R));
return res;
}
void update_set(int cur, int l, int r, int L, int R, ll x) {
if (L <= l && R >= r) {
set_helper(cur, x, x);
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update_set(cur<<1, l, mid, L, R, x);
if (R > mid) update_set(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
void update_add(int cur, int l, int r, int L, int R, ll x) {
if (L <= l && R >= r) {
add_helper(cur, x, x);
return;
}
push_down(cur, l, r);
int mid = (l+r) >> 1;
if (L <= mid) update_add(cur<<1, l, mid, L, R, x);
if (R > mid) update_add(cur<<1|1, mid+1, r, L, R, x);
push_up(cur);
}
} seg;
int n, m;
int main() {
cin >> n;
for (int i = 1; i <= n; i++) cin >> seg.a[i];
seg.build(1, 1, n);
cin >> m;
while (m--) {
char op; int l, r;
cin >> op >> l >> r;
if (op == 'Q') {
cout << seg.query_maxa(1, 1, n, l, r) << "\n";
} else if (op == 'A') {
cout << seg.query_maxb(1, 1, n, l, r) << "\n";
} else if (op == 'P') {
ll x; cin >> x;
seg.update_add(1, 1, n, l, r, x);
} else if (op == 'C') {
ll x; cin >> x;
seg.update_set(1, 1, n, l, r, x);
}
}
}