并查集
Contents
介绍
这篇博客主要介绍一些并查集的高级应用。
目前我所知的并查集应用有:
- 权值并查集
- 可撤销并查集
- 可持久化并查集(还没学)
权值并查集
略,目前我所知的方法就只有开多倍空间,维护不同元素之间的关系,经典例题见 NOI2001 食物链。
可撤销并查集
意思如其名,可以将之前合并的操作撤销。
思想并不难,只要拿一个额外的 stack
记录一下每次操作 更改的信息,然后回退的时候将 stack
内的信息 pop 出来,恢复一下当时的状态即可。
有几个需要注意的点:
- 并查集 不能路径压缩,否则会破坏结构。为了保证时间复杂度,需要使用启发式合并。
- 无论
unions(u,v)
是否成功,都要记录在栈中。因为我们回退时并不知道这一步合并当时是否成功了。
模版
ll ans = 0;
ll cal(ll x) {
return x*(x-1) / 2;
}
struct State {
int u, v, szu, szv;
} st[maxm]; // 注意这里是 maxm,应该是边的数量
struct DSU {
int par[maxn], sz[maxn], tail = 0;
inline void init() {
for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1;
tail = 0;
}
int finds(int u) {
if (par[u] == u) return u;
return finds(par[u]); // 无路径压缩
}
void unions(int u, int v) {
u = finds(u), v = finds(v);
if (sz[u] < sz[v]) swap(u,v); // sz[u] >= sz[v]
st[++tail] = {u, v, sz[u], sz[v]}; // 无论是否 union 成功都要push到 stack 里
if (u == v) return;
par[v] = u;
ans = ans - cal(sz[u]) - cal(sz[v]) + cal(sz[u] + sz[v]);
sz[u] += sz[v];
}
void cancel() {
if (tail > 0) {
int u = st[tail].u, v = st[tail].v;
par[v] = v;
if (sz[u] != st[tail].szu) {
ans = ans - cal(sz[u]) + cal(st[tail].szu) + cal(st[tail].szv);
}
sz[u] = st[tail].szu;
sz[v] = st[tail].szv;
tail--;
}
}
} dsu;
并查集维护仅删除的单向链表/树
并查集还可以用于维护一个只有删除操作的单向链表 / 树,维护的信息是一个节点的 “下一个xx” 相关信息。
例如在一棵树上,支持两种操作:
- 删除一个节点 $u$,将其所有 child 连到节点 $u$ 的parent上。
- 查询一个节点的parent。
就可以将它看作并查集问题,每个节点 $u$ 的并查集中,par[u]
记录的是:在它的上方,第一个还存在的节点是谁。
所以当一个节点还存在时,par[u] = u
,当它被删除时,就 unions(u, p)
,其中 $p$ 是 $u$ 在树上的 parent。
• 需要注意,unions(u,p)
时不能随便 unions
,我们需要保证 par[finds(u)] = finds(p)
。
例题
例1 [HNOI2016]最小公倍数
题意
给定一个 $n$ 个节点,$m$ 条边的无向图。每个边有权值,权值都可以被表示为 $2^a \times 3^b$ 的形式。
现在给出 $q$ 个询问,每次询问 $u ~ v ~ a ~ b$,我们需要回答在 $u,v$ 之间,是否存在一条路径使得路径上所有边权的 LCM 等于 $2^a \times 3^b$?
路径可以不为简单路径(可以经过重复的边)。
其中,$1 \leq n,q \leq 5 \times 10^4, 1 \leq m \leq 10^5, 0 \leq a_i,b_i \leq 10^9$。
题解
题目实际上就想要我们找,每条边的权值表示为 $(a_i,b_i)$,每次询问 $u ~ v ~ a ~ b$,回答是否存在路径使得路径上的 $\max \{a_i\} = a, \max \{b_i\} = b$?
我们先用最暴力的思路来想:
对于每次询问,我们只要把所有满足 $a_i \leq a, b_i \leq b$ 的边全都加进去,然后判断一下 $u,v$ 是否联通,并且这个联通块内 $\max \{a_i\} = a, \max \{b_i\} = b$ 是否成立即可。这可以用并查集实现。
接下来考虑优化。
对于这种拥有 $2$ 个条件的题目,一种套路是:
将第一种权值
sort
一下,然后分块。每个块内(或者多个块一起)对第二种权值
sort
。然后把所有询问放进对应的块内进行处理。
对于本题,我们先把所有的边按照 $a_i$ 的大小进行 sort
,然后进行分块。
对于每一条边 $(u,v,a,b)$,我们找到一个块 $B$ 满足,在 $[1,B-1]$ 这些块内,最大的 $a_i$ 都 $\leq a$,且这个 $B$ 尽可能大。
之后,我们对于每一个块进行处理。
假设我们现在在块 $B$ 内,那么我们将 $[1,B-1]$ 这些块看作一个整体,然后对这个整体,按照 第二种权值 $b_i$ 进行 sort
。
然后我们将块 $B$ 内的所有询问按照 $b_i$ 进行 sort
。
然后我们开始回答询问 $(u,v,a,b)$。
首先注意到此时,$[1,B-1]$ 这些块内的元素一定满足 $a_i \leq a$,所以不用关心前面这些块内的 $a_i$。
然后考虑 $[1,B-1]$ 这些块内的元素的 $b_i$,注意到它们此时是按照 $b_i$ sort
好的,而我们的询问也是按照 $b$ sort
的,所以我们可以直接用一个指针维护一下现在我们在 $[1,B-1]$ 这些块内,有哪些元素的 $b_i$ 是 $\leq b$ 的。
这样,$[1,B-1]$ 这些块就都处理好了,只剩下当前这个块 $B$ 了。
对于当前块 $B$,我们只要暴力向并查集内加入当前块内,所有满足 $a_i \leq a, b_i \leq b$ 的边即可。
在回答完这个询问后,把当前块 $B$ 的所有边都从并查集中撤销,然后开始回答下一个询问。
总时间复杂度:设块的大小为 $B$,那么总共有 $\frac{m}{B}$ 个块。
每一个块的处理:首先要对前面的块进行sort,所以是 $O(m\log m)$。
对于每一个询问,我们都要暴力遍历一个块内的所有边,总共有 $B$ 条边,每次加入操作需要 $O(\log m)$(因为没有路径压缩),所以是 $O(B \log m)$。
于是最终的复杂度就是:
$$T(n) = \frac{m}{B} m\log m + qB \log m$$
取 $B = \sqrt m$ 或者 $\frac{m}{\sqrt q}$ 之类的都可以过。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e4+5;
const int maxm = 1e5+5;
struct Query {
int u, v, a, b, be, id;
} query[maxn];
vector<Query> vec[maxn];
struct Edge {
int from, to, a, b;
} edges[maxm];
bool cmpA(const Edge& e1, const Edge& e2) {
return e1.a < e2.a;
}
bool cmpB(const Edge& e1, const Edge& e2) {
return e1.b < e2.b;
}
int ecnt = 0, n, m, B, blockcnt, Q;
void addEdge(int u, int v, int a, int b) {
Edge e = {u, v, a, b};
edges[++ecnt] = e;
}
struct State {
int u, v, szu, amaxu, bmaxu;
} st[maxm];
int tail = 0;
int par[maxn], sz[maxn], amax[maxn], bmax[maxn];
inline void init() {
for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1, amax[i] = bmax[i] = -1;
tail = 0;
}
int finds(int u) {
if (par[u] == u) return u;
return finds(par[u]);
}
void unions(int u, int v, int a, int b) {
u = finds(u), v = finds(v);
if (sz[u] < sz[v]) swap(u,v);
st[++tail] = {u, v, sz[u], amax[u], bmax[u]};
amax[u] = max(amax[u], a); // 注意不管是否成功,都要更改 amax, bmax
bmax[u] = max(bmax[u], b);
if (u == v) return;
par[v] = u;
sz[u] += sz[v];
amax[u] = max(amax[u], amax[v]);
bmax[u] = max(bmax[u], bmax[v]);
}
void cancel() {
if (tail > 0) {
int u = st[tail].u, v = st[tail].v;
par[v] = v;
sz[u] = st[tail].szu;
amax[u] = st[tail].amaxu;
bmax[u] = st[tail].bmaxu;
tail--;
}
}
bool ans[maxn];
int main() {
cin >> n >> m;
for (int i = 1; i <= m; i++) {
int u,v,a,b; cin >> u >> v >> a >> b;
addEdge(u,v,a,b);
}
sort(edges+1, edges+ecnt+1, cmpA);
B = sqrt(n); // block size
blockcnt = (ecnt + B - 1) / B;
cin >> Q;
// block 1: [1 ... B], block 2: [B+1, ... , 2B],
for (int i = 1; i <= Q; i++) {
cin >> query[i].u >> query[i].v >> query[i].a >> query[i].b;
query[i].id = i;
query[i].be = 1;
for (int b = 1; b <= blockcnt-1; b++) {
if (edges[b * B].a <= query[i].a) {
query[i].be = b+1;
}
}
vec[query[i].be].push_back(query[i]);
}
for (int b = 1; b <= blockcnt; b++) {
if (!vec[b].size()) continue; // 若这个块为空
sort(vec[b].begin(), vec[b].end(), [](auto a, auto b) {
return a.b < b.b; // 把所有询问按照 b 的大小 sort 一下
});
init();
if (b > 1) {
int R = (b-1) * B; // 前一个块的右端点
sort(edges+1, edges+R+1, cmpB);
// 开始处理询问
int j = 0;
for (Query que : vec[b]) {
int qa = que.a, qb = que.b;
// 先将前面的块满足条件的都加进去
while (j + 1 <= R && edges[j+1].b <= qb) {
unions(edges[j+1].from, edges[j+1].to, edges[j+1].a, edges[j+1].b);
j++;
}
// 然后开始处理当前块
int curt = tail; // 记录当前 stack 的 tail 位置
for (int i = R+1; i <= min(b*B, ecnt); i++) { // 当前块
if (edges[i].a <= qa && edges[i].b <= qb) {
unions(edges[i].from, edges[i].to, edges[i].a, edges[i].b);
}
}
int u = finds(que.u), v = finds(que.v);
if (u == v && amax[u] == qa && bmax[u] == qb) {
ans[que.id] = 1;
}
while (tail != curt) cancel(); // 退回当前块的所有操作
}
}
}
for (int i = 1; i <= Q; i++) {
cout << (ans[i] ? "Yes" : "No") << "\n";
}
}
注意事项
- 在
unions(u,v,a,b)
时,无论u,v
是否已经联通,都要更新a,b
。 - 可撤销并查集的栈只记录被更改的信息,其他的就不用记了。
例2 BZOJ4668 冷战
题意
给定 $n$ 个点 和 $m$ 次询问。
每次询问有两种格式:
$0 ~ u ~ v$:将 $u,v$ 用一条边链接起来。
$1 ~ u ~ v$:询问最早在加入哪条边以后,$u,v$ 在同一个联通块中。如果此时尚未联通,则输出 $0$。
所有询问强制在线。
其中,$n,m \leq 5 \times 10^5$。
题解
启发式合并并查集的题。
第一反应是每次链接的时候,新建一个 parent 节点储存这是第几条边,然后把 $u’, v'$ 作为它的两个 child。($u’,v'$ 是在 finds(u), finds(v)
得到的根节点)。类似于 kruskal 重构树的思路。
然后每次询问 $1 ~ u ~ v$ 的时候,直接找到 LCA 位置的那个节点即可。
这样是正确的,不过有个更优雅的写法:
注意到有效的链接只有 $(n-1)$ 次,普通的启发式合并刚好生成的就是一棵树,也不用新建节点了。直接启发式合并的时候,把当前是第几条边,记录在链接用的那一条树边上。
这样,$(n-1)$ 条树边,刚好对应的就是 $(n-1)$ 次启发式合并。
每次询问 $1 ~ u ~ v$ 的时候,直接暴力往上跳到 LCA,然后取边上的最大值即可。(就相当于 $u,v$ 之间链上的边权最大值)。
由于启发式合并的性质,暴力跳 LCA 的时间是 $O(\log n)$。
所以复杂度就是 $O(n \log n)$。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
int n,m, dep[maxn], par[maxn], sz[maxn], val[maxn];
int finds(int u) {
if (par[u] == u) {
dep[u] = 0;
return u;
}
int res = finds(par[u]);
dep[u] = dep[par[u]] + 1;
return res;
}
void unions(int u, int v, int id) {
u = finds(u), v = finds(v);
if (u == v) return;
if (sz[u] < sz[v]) swap(u,v);
sz[u] += sz[v];
par[v] = u;
val[v] = id; // 链接 (v,u) 的边权为 id
}
// 寻找 u,v 链上最大的值
int find_max(int u, int v) {
if (finds(u) != finds(v)) return 0;
if (dep[u] < dep[v]) swap(u,v);
int res = 0;
while (dep[u] > dep[v]) {
res = max(res, val[u]);
u = par[u];
}
if (u == v) return res;
while (u != v) {
res = max({res, val[u], val[v]});
u = par[u];
v = par[v];
}
return res;
}
int main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) {
par[i] = i;
sz[i] = 1;
}
int last = 0, id = 0;
while (m--) {
int op,u,v; cin >> op >> u >> v;
u ^= last, v ^= last;
if (op == 0) {
unions(u, v, ++id);
} else {
int res = find_max(u, v);
cout << res << "\n";
last = res;
}
}
}
例3 Do Not Try This Problem
题意
给定一个长度为 $n$ 的 string $s$,给定 $q$ 个操作,每次操作的格式是 $i~a~k~c$,将从index $i$ 开始,将 $s_i,s_{i+a},…,s_{i+ka}$ 替换为字符 $c$。
输出所有操作结束后的string。
其中,$n, q \leq 10^5, i \in [1,n], a \in [1, n), k \in [0,n), i+ka \leq n$。
题解
由于询问离线,所以可以从后向前处理。因为后面的操作优先级更高,当一个后面的操作更改了某个位置,说明之后再也不用改这个位置了。
然后考虑分块,对于 $a > \sqrt n$ 的操作,直接暴力修改即可。
对于 $a \leq \sqrt n$ 的操作,我们利用 并查集维护仅带删除的单向链表。
由于一个位置被更改以后,就再也不用改这个位置了,把它当作被修改后,直接删除这个位置,那么我们维护 $\sqrt n$ 种链表,第 $j$ 种链表将会维护 $j$ 个跨度为 $j$ 的链表,例如 $j=3$,那么维护的就是 $1 \rightarrow 4 \rightarrow 7 \rightarrow …$ 和 $2 \rightarrow 5 \rightarrow 8 \rightarrow…$ 和 $3 \rightarrow 6 \rightarrow 9 \rightarrow …$。
而每一种链表就是一个并查集,每一个元素的parent是它 下一个存在的元素。
在删除一个元素时,更新其 parent 即可。
按理说,删除一个元素时,应该更新其所有的parent,即更新 $\sqrt n$ 种链表,如:
void del(int p, char c) { // 删除位置p
if (vis[p]) return;
vis[p] = 1;
s[p] = c;
for (int i = 1; i < m; i++) {
uf[i].unions(p, min(p+i, n+1));
}
}
但是这样每次删除一个元素都要更新,虽然复杂度还是 $O(n \log n)$ ,但常数略高会T。
我们改成在询问 $< \sqrt n$ 的情况下才更新,并且只更新对应的 $a$,详情见代码。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+3;
struct UnionFind {
int par[maxn];
int finds(int u) {
if (par[u] == u) return u;
return par[u] = finds(par[u]);
}
void unions(int u, int v) {
u = finds(u), v = finds(v);
if (u < v) swap(u, v);
par[v] = u;
}
} uf[320];
struct Query {
int i, a, k;
char c;
} q[maxn];
int n, m, Q;
string s;
bool vis[maxn];
void del(int p, char c) { // 删除位置p
if (vis[p]) return;
vis[p] = 1;
s[p] = c;
}
int main() {
fastio;
cin >> s;
n = s.size();
s = "#" + s;
m = sqrt(n);
for (int i = 1; i < m; i++) {
for (int j = 1; j <= n+1; j++) uf[i].par[j] = j;
}
cin >> Q;
for (int i = 1; i <= Q; i++) cin >> q[i].i >> q[i].a >> q[i].k >> q[i].c;
for (int i = Q; i >= 1; i--) {
auto [st, a, k, c] = q[i];
if (a >= m) {
for (int j = st; j <= st+a*k; j += a) {
del(j, c);
}
} else {
int j = st;
while (j <= st+a*k) {
del(j, c);
uf[a].unions(j, min(j+a, n+1));
j = uf[a].finds(j);
}
}
}
for (int i = 1; i <= n; i++) cout << s[i];
cout << "\n";
}
例4 TJOI2016 树
题意
给定一棵树,根为 $1$,现在有两个操作:
C u
:给节点 $u$ 打标记(一个节点可能被打上多次标记)。
Q u
:询问节点 $u$ 打了标记的最近祖先(算上它自己)。
一开始 $1$ 被标记了,其他所有节点无标记。
其中,$n,q \leq 10^5$。
题解
离线操作,将问题反过来看:
先给所有节点打上标记,然后利用并查集来维护支持删点的树。并查集上维护每一个节点上方的第一个打了标记的点。
然后将所有询问从后往前处理,遇到一个标记操作,就将这个节点的标记次数 -1(因为一个节点可以被标记多次),当一个节点的标记次数归0时,将这个节点从树中删除。
这样就可以 $O(n)$ 解决这个问题了。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
pair<char, int> q[maxn];
vector<int> adj[maxn];
int f[maxn];
int par[maxn], dep[maxn];
struct UnionFind {
int par[maxn];
int finds(int u) {
if (par[u] == u) return u;
return par[u] = finds(par[u]);
}
void unions(int u, int v) {
u = finds(u), v = finds(v);
if (dep[u] > dep[v]) swap(u, v);
par[v] = u;
}
} uf;
void dfs(int u, int p) {
par[u] = p;
dep[u] = dep[p] + 1;
if (!f[u]) {
uf.unions(u, p);
}
for (int v : adj[u]) {
if (v == p) continue;
dfs(v, u);
}
}
int main() {
int n, Q; cin >> n >> Q;
for (int i = 1; i < n; i++) {
int u, v; cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
f[1] = 1;
for (int i = 1; i <= Q; i++) {
cin >> q[i].first >> q[i].second;
if (q[i].first == 'C') {
f[q[i].second]++;
}
}
for (int i = 1; i <= n; i++) uf.par[i] = i;
dfs(1, 0);
vector<int> ans;
for (int i = Q; i >= 1; i--) {
int u = q[i].second;
if (q[i].first == 'C') {
f[u]--;
if (!f[u])
uf.unions(u, par[u]);
} else {
ans.push_back(uf.finds(u));
}
}
reverse(ans.begin(), ans.end());
for (int i : ans) cout << i << "\n";
}