介绍

线段树分治一般用于处理一些 “只存在/不存在 权值为 $w$“,或者 ”某个特定时间点” 对应状态的问题。

原理是在线段树上进行 DFS,进入 DFS 时加入节点上的所有信息,回溯时则删除这些信息,当达到了叶子节点时就可以得到对应答案。

• 由于有撤销操作,所以题目一般得是支持撤销的才行。

经典例题

题意

有一个 $n$ 个节点的无向图,现在给定一个时间轴 $1…k$,有 $m$ 条边,每条边 $(u,v)$ 将会在时间段 $[l,r]$ 出现,其他时间消失。

现在对于每一个时间点,都回答:在这个时间点时,图是否为二分图。

其中,$n,k \leq 10^5, m \leq 2 \times 10^5, [l,r] \in [1,k]$。

首先注意到,可以利用权值并查集来动态维护是否为二分图。

然后我们构造一个线段树,当我们有一条在 $[l,r]$ 时间段的边 $(u,v)$ 时,就将这条边塞到 线段树的 $[l,r]$ 区间里,这样每个线段树节点维护一个 vector<Edge> 来保存这个时间段,存在哪些边。

然后对线段树进行 DFS,进入一个节点时就将它上面的所有边加入并查集,回溯时就将这些边从并查集删除(所以需要可撤销并查集)。

• 线段树其实就是一个容器,用来储存 Edge 的,本质上并没有维护 “区间信息”(有点类似 这道题 中线段树的作用),所以也没有 push_up, push_down

• 当我们检测到已经不是二分图的时候就可以停止继续往里面 DFS,直接回溯即可,这样也避免了一些复杂的问题,比如非二分图撤销边以后如何判断是不是二分图。

复杂度:每条边对应的时间段 $[l,r]$ 至多会被拆成 $O(\log k)$ 个线段树区间,而每条边都只会被加入和删除一次。而可撤销并查集的复杂度是 $O(\log n)$ 的,所以总复杂度为 $O(n \log n \log k)$。

例题

例1 洛谷P5787 二分图 /【模板】线段树分治

题意

如上。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 6e5+55;

struct State {
    int u, v, szu, szv;
} st[maxm];  // 注意这里是 maxm,应该是边的数量
bool ok = 1;
int n, m, k;
struct DSU {
    int par[maxn], sz[maxn], tail = 0;
    inline void init() {
        for (int i = 1; i <= n*2; i++) par[i] = i, sz[i] = 1;
        tail = 0;
    }
    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);  // 无路径压缩
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u] < sz[v]) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};   // 无论是否 union 成功都要push到 stack 里
        if (u == v) return;
        par[v] = u;
        sz[u] += sz[v];

    }
    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            par[v] = v;
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            tail--;
        }
    }
} dsu;

struct Edge {
    int u, v;
};
struct TreeNode {
    vector<Edge> vec;
};
bool ans[maxn];
int bad = 0;
struct SegmentTree {
    TreeNode tr[maxn<<2];
    void insert(int cur, int l, int r, int L, int R, Edge e) {
        if (L <= l && R >= r) {
            tr[cur].vec.push_back(e);
            return;
        }
        int mid = (l+r) >> 1;
        if (L <= mid) insert(cur<<1, l, mid, L, R, e);
        if (R > mid) insert(cur<<1|1, mid+1, r, L, R, e);
    }
    void dfs(int cur, int l, int r) {
        for (Edge& e : tr[cur].vec) {
            if (dsu.finds(e.u) == dsu.finds(e.v)) bad++;
            dsu.unions(e.u, e.v + n);
            dsu.unions(e.u + n, e.v);
        }

        if (bad) {
            for (int i = l; i <= r; i++) ans[l] = 0;
            bad = 0;
        } else {
            if (l == r) {
                ans[l] = 1;
            } else {
                int mid = (l+r) >> 1;
                dfs(cur<<1, l, mid);
                dfs(cur<<1|1, mid+1, r);
            }
        }

        for (Edge& e : tr[cur].vec) {
            dsu.cancel();
            dsu.cancel();
        }
    }
} tr;

int main() {
    fastio;
    cin >> n >> m >> k;
    dsu.init();
    for (int i = 1; i <= m; i++) {
        int u, v, l, r; cin >> u >> v >> l >> r;
        l++;
        if (l <= r) {
            tr.insert(1, 1, k, l, r, {u,v});
        }
    }
    tr.dfs(1, 1, k);
    for (int i = 1; i <= k; i++) cout << (ans[i] ? "Yes" : "No") << "\n";
}

例2 洛谷P5631 最小mex生成树

题意

给定一个带权的无向联通图。

求这个图的生成树,使得边权集合的 mex 最小,输出最小mex即可。

其中,$n \leq 10^6, m \leq 2 \times 10^6, w \in [0,10^5]$。

题解

求出最小 mex 只需要对于每一个 $w$,将权值等于 $w$ 的边全部去掉,剩下的都加上,看是否能形成生成树即可。

上一个例题中,线段树分治用于解决 “只包含权值为 $w$ 的边”,这个题我们可以反过来。例如有一条权值为 $3$ 的边,那么我们将它加入 $[0,2] \cup [4,10^5]$ 这两个区间内,然后再跑线段树分治,在 $3$ 的时候就代表 “仅不包含权值为 $w$ 的边” 了。


但是注意到这个方法 不能用于求最大 mex 生成树

比如有三条边的权值分别为 $0,0,1$,任取两条边可以组成生成树,那么首先我们在去掉任意 $\geq 1$ 的边后都能形成生成树,答案看起来是 $1$,但实际上应该选权值为 $0,1$ 的两条边,答案可以得到 $2$。

简单来说,去掉一个特定权值 $w$,能得到生成树,并不代表这个生成树的 mex 就等于 $w$。

为什么最小mex就可以呢?

假如我们去掉一个特定权值 $w$,能得到生成树,但实际上这个生成树的 mex 不等于 $w$,那么真实的 mex $m$ 一定是 $< w$ 的,那么我们在继续判断更小的权值时,一定会判断到真实的 mex $m$(因为线段树分治本质上是暴力枚举)。枚举到 $m$ 时也肯定能组成一个生成树,所以答案为 $m$。

代码
#include <bits/stdc++.h>
using namespace std;

const int maxn = 1e6+5;
const int maxm = 4e6+55;

struct State {
    int u, v, szu, szv;
} st[maxm];  // 注意这里是 maxm,应该是边的数量
int uni = 0;  // union 的次数,如果等于 n-1 说明联通了
int n, m, k;
struct DSU {
    int par[maxn], sz[maxn], tail = 0;
    inline void init() {
        for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1;
        tail = 0;
    }
    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);  // 无路径压缩
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u] < sz[v]) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};   // 无论是否 union 成功都要push到 stack 里
        if (u == v) return;
        par[v] = u;
        sz[u] += sz[v];
        uni++;  // 成功union
    }
    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            par[v] = v;
            if (sz[u] != st[tail].szu) uni--;  // 成功回退一次union
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            tail--;
        }
    }
} dsu;

struct Edge {
    int u, v, w;
};
struct TreeNode {
    vector<Edge> vec;
};

int ans = 1e5;
struct SegmentTree {
    TreeNode tr[(100005)<<2];
    void insert(int cur, int l, int r, int L, int R, Edge e) {
        if (L <= l && R >= r) {
            tr[cur].vec.push_back(e);
            return;
        }
        int mid = (l+r) >> 1;
        if (L <= mid) insert(cur<<1, l, mid, L, R, e);
        if (R > mid) insert(cur<<1|1, mid+1, r, L, R, e);
    }
    void dfs(int cur, int l, int r) {
        for (Edge& e : tr[cur].vec) {
            dsu.unions(e.u, e.v);
        }

        if (l == r) {
            if (uni == n-1) ans = min(ans, l);
        } else {
            int mid = (l+r) >> 1;
            dfs(cur<<1, l, mid);
            dfs(cur<<1|1, mid+1, r);
        }

        for (Edge& e : tr[cur].vec) {
            dsu.cancel();
        }
    }
} tr;

int main() {
    fastio;
    cin >> n >> m;
    dsu.init();
    int N = 1e5;
    for (int i = 1; i <= m; i++) {
        int u, v, w; cin >> u >> v >> w;
        if (w > 0)
            tr.insert(1, 0, N, 0, w-1, {u, v, w});
        if (w < N)
            tr.insert(1, 0, N, w+1, N, {u, v, w});
    }
    tr.dfs(1, 0, N);
    cout << ans << "\n";
}

例3 CF1140F. Extending Set of Points

题意

给定一个包含二维平面中的点的集合 $S$,定义操作 $E(S)$ 为:

初始定义 $R=S$,如果对于任何的 $x_1,y_1,x_2,y_2$,如果 $(x_1,y_1) \in R, (x_1,y_2) \in R, (x_2,y_1) \in R$ 且 $(x_2,y_2) \notin R$,则将 $(x_2,y_2)$ 加入到 $R$,直到这种操作无法继续进行。

定义 $E(S)$ 为上述操作后得到的 $R$ 的大小(这个操作不影响 $S$ 本身)。

现在给定 $n$ 个询问,每次询问为 $(x,y)$,如果这个点不在 $S$ 内,将其加入 $S$。否则将其从 $S$ 中删除。

每次询问后,回答 $E(S)$ 的大小。

其中,$n \leq 3 \times 10^5, x,y \in [1,3\times 10^5]$。

题解

经典套路:将 $x,y$ 坐标轴看作二分图里面的点,将二维平面上的点看作二分图里的边。

然后就能发现 $E(S)$ 本质上是计算了二分图里所有联通块的 左点数量 乘以 右点数量 的和。

现在问题来了,怎么在增加/删除边的情况下,动态维护并查集?

没有办法!

但注意到可以把 $n$ 个询问看作一个时间从 $1$ 到 $n$ 的时间轴,离线处理询问后,可以知道每条边所在的时间区间。那么这个问题就跟 例1 几乎一模一样了。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 6e5+5;

const int M = 3e5;
struct State {
    int u, v;
    pll szu, szv;
} st[maxn];  // 注意这里是 maxm,应该是边的数量
ll res = 0;
ll ans[maxn];
struct DSU {
    int par[maxn], tail = 0;
    pll sz[maxn];
    inline void init() {
        for (int i = 1; i <= M*2; i++) {
            par[i] = i;
            if (i <= M) sz[i] = {1, 0};
            else sz[i] = {0, 1};
        }
        tail = 0;
    }
    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);  // 无路径压缩
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u].first + sz[u].second < sz[v].first + sz[v].second) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};   // 无论是否 union 成功都要push到 stack 里
        if (u == v) return;
        par[v] = u;
        res -= sz[u].first * sz[u].second;
        res -= sz[v].first * sz[v].second;
        sz[u].first += sz[v].first;
        sz[u].second += sz[v].second;
        res += sz[u].first * sz[u].second;
    }
    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            bool f = 0;
            if (sz[u] != st[tail].szu) {
                res -= sz[u].first * sz[u].second;
                f = 1;
            }
            par[v] = v;
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            if (f) {
                res += sz[u].first * sz[u].second;
                res += sz[v].first * sz[v].second;
            }
            tail--;
        }
    }
} dsu;

struct Edge {
    int u, v;
};
struct TreeNode {
    vector<Edge> vec;
};
struct SegmentTree {
    TreeNode tr[maxn<<2];
    void insert(int cur, int l, int r, int L, int R, Edge e) {
        if (L <= l && R >= r) {
            tr[cur].vec.push_back(e);
            return;
        }
        int mid = (l+r) >> 1;
        if (L <= mid) insert(cur<<1, l, mid, L, R, e);
        if (R > mid) insert(cur<<1|1, mid+1, r, L, R, e);
    }
    void dfs(int cur, int l, int r) {
        for (Edge& e : tr[cur].vec) {
            dsu.unions(e.u, e.v);
        }
        if (l == r) {
            ans[l] = res;
        } else {
            int mid = (l+r) >> 1;
            dfs(cur<<1, l, mid);
            dfs(cur<<1|1, mid+1, r);
        }
        for (Edge& e : tr[cur].vec) {
            dsu.cancel();
        }
    }
} tr;


int n;
map<pii, int> mp;
int main() {
    cin >> n;
    dsu.init();
    for (int t = 1; t <= n; t++) {
        int u, v; cin >> u >> v; v += M;
        if (mp.count({u,v})) {
            int pt = mp[{u,v}];
            mp.erase({u,v});
            tr.insert(1, 1, n, pt, t-1, {u,v});
        } else {
            mp[{u,v}] = t;
        }
    }
    for (auto itr : mp) {
        tr.insert(1, 1, n, itr.second, n, {itr.first.first, itr.first.second});
    }
    tr.dfs(1, 1, n);
    for (int i = 1; i <= n; i++) cout << ans[i] << " ";
    cout << "\n";
}

例4 CF1814F. Communication Towers

题意

有 $n$ 个点,$m$ 条双向的边。

点 $i$ 只能接受频率在 $[l_i, r_i]$ 之间的信号。

如果存在一个信号频率 $x$,使得存在一条从 $1$ 到 $x$ 的路径,并且路径上的所有点都接受频率 $x$ 的信号,那么说明 $x$ 是可达的。

问有多少个节点 $x$ 是可达的。

其中,$n \leq 2 \times 10^5, m \leq 4 \times 10^5, 1 \leq l_i \leq r_i \leq 2 \times 10^5$。

题解

一眼线段树分治,把信号频率当作时间,这样可以得到每一个频率信号对应的 $1$ 的联通块。

但问题是,在得到一个频率信号的联通块以后,不能直接暴力统计答案,否则是 $O(n^2)$ 的。

所以考虑在并查集上打标记。

具体来说:当我们 dfs 到线段树的叶子节点时,就可以得到了 $1$ 所在的联通块,这时在并查集中,给 $1$ 的联通块的parent打上标记。


问题来了:什么时候下传标记?如何下传标记?毕竟我们没有记录每个节点的 children。

我们在回溯时,下传标记,刚好回溯时,是将 $u,v$ 断开,假设 $u$ 为回溯前的 parent,那么此时我们就可以让 flag[v] += flag[u](注意不要清空 $u$ 的标记,因为之后要下传给别的child)。

那出现了另外一个问题:假设我给节点 $1$ 打上了标记,此时的联通块是 [1,2,3],但之后 $4$ 连上了 $1$,联通块变成 [1,2,3,4] 了,在之后 $4$ 与 $1$ 断开时,$1$ 不是会将标记下传给 $4$ 吗?这样就不对了。

确实,所以我们在 unions(u, v) 时,要将 flag[v] -= flag[u],因为之后回溯时肯定有 flag[v] += flag[u],所以我们先把这个影响减去,就可以了。

总结一下:

  1. 在叶子节点给 $1$ 所在联通块的根打标记。
  2. 在回溯时,下传标记。
  3. unions 时,减去标记。

最后,答案就是所有被标记的节点。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 4e5+500;

struct State {
    int u, v, szu, szv;
} st[maxm];  // 注意这里是 maxm,应该是边的数量
int n, m, k;
struct DSU {
    int par[maxn], sz[maxn], flag[maxn], tail = 0;
    inline void init() {
        for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1;
        tail = 0;
    }
    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);  // 无路径压缩
    }
    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u] < sz[v]) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};   // 无论是否 union 成功都要push到 stack 里
        if (u == v) return;
        par[v] = u;
        sz[u] += sz[v];
        flag[v] -= flag[u];
    }
    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            par[v] = v;
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            tail--;
            if (u != v)  // 保证不会重复
                flag[v] += flag[u];
        }
    }
} dsu;
struct Edge {
    int u, v;
};
struct TreeNode {
    vector<Edge> vec;
};
struct SegmentTree {
    TreeNode tr[maxn<<2];
    void insert(int cur, int l, int r, int L, int R, Edge e) {
        if (L <= l && R >= r) {
            tr[cur].vec.push_back(e);
            return;
        }
        int mid = (l+r) >> 1;
        if (L <= mid) insert(cur<<1, l, mid, L, R, e);
        if (R > mid) insert(cur<<1|1, mid+1, r, L, R, e);
    }
    void dfs(int cur, int l, int r) {
        for (Edge& e : tr[cur].vec) {
            dsu.unions(e.u, e.v);
        }
        
        if (l == r) {
            dsu.flag[dsu.finds(1)]++;
        }
        else {
            int mid = (l+r) >> 1;
            dfs(cur<<1, l, mid);
            dfs(cur<<1|1, mid+1, r);
        }

        for (Edge& e : tr[cur].vec) {
            dsu.cancel();
        }
    }
} tr;


int l[maxn], r[maxn];
int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        cin >> l[i] >> r[i];
    }
    int N = 2e5;
    for (int i = 1; i <= m; i++) {
        int u, v; cin >> u >> v;
        int L = max(l[u], l[v]), R = min(r[u], r[v]);
        if (L <= R) {
            tr.insert(1, 1, N, L, R, {u,v});
        }
    }
    dsu.init();
    tr.dfs(1, 1, N);
    for (int i = 1; i <= n; i++) {
        if (dsu.flag[i]) cout << i << " ";
    }
    cout << "\n";
}