介绍

点分治用于处理 树上的路径问题

点分治的主要思想是,对于一棵子树,子树内的所有路径只有两种情况:

  1. 不经过根节点
  2. 经过根节点

对于第一种,我们可以在处理其他子树的时候再讨论。

对于第二种,注意到一个经过根节点的路径,可以被拆分成从根节点出发的一条路径,合并上另外一条从根节点出发的路径。

所以点分治的核心思想就是对于每一个子树,都只考虑从根节点出发的路径,这些路径有 O(n) 条。

但极端情况下,比如一条链,这样的复杂度可能来到 O(n2),所以在寻找一个子树的根时,应该将这个子树的重心作为根,这样递归的时候深度最多就是 O(logn)

总体上来说复杂度就是 O(nlogn)

模版
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
int n, m;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
bool ans[maxn], has_dis[maxm]; // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
// 计算 u 为根,所有以 u 出发的路径带来的贡献
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
// 将子树 v 的贡献加进 ans
for (int p = hd; p <= tail; p++) { // 遍历子树 v 的节点
for (int k = 1; k <= m; k++) { // 遍历每一个询问
int q_dis = queries[k];
if (q_dis >= q[p]) {
ans[k] |= has_dis[q_dis - q[p]];
}
}
}
// 考虑完子树 v 以后,将子树 v 的信息储存进去
for (int p = hd; p <= tail; p++) {
has_dis[q[p]] = 1;
tmp.push_back(q[p]);
}
}
for (int d : tmp) has_dis[d] = 0; // 清空 u 的子树信息
tmp.clear();
}
// 分治 u
void solve(int u) {
vis[u] = 1; dis[u] = 0;
has_dis[0] = 1; // 初始情况
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
for (int i = 1; i <= m; i++) {
int k; cin >> k;
queries[i] = k;
}
find_centroid(1, 0);
solve(rt);
for (int i = 1; i <= m; i++) {
cout << (ans[i] ? "AYE" : "NAY") << "\n";
}
}

在模版中,主要修改的部分就是 calc(u) 函数,这个函数代表着题目要求计算什么样的路径。

例题

例1 洛谷P3806 点分治1

题意

给定一棵有 n 个点的树,边有权值。

m 次询问,每次询问树上距离为 k 的点对是否存在。

其中,n104,m100,1k107,1w104

题解

注意到 k107,所以可以开一个数组来储存长度为 x 的路径(从根出发)是否存在。

然后就是点分治的模版了,有几个点可能需要注意:

  1. 点分治的 calc(u) 过程里,枚举了每个子树 v,枚举一个 v 得到子树信息以后,先将信息贡献给 ans[],然后才储存进当前子树内。这是为了防止出现非法的情况,比如一个路径,两个端点都在同一个子树 v 内,这也类似于树形 DP 的思想。
  2. 使用了一个队列 q 来储存子树 v 的信息。
  3. 使用了一个 vector<> tmp 来储存整个 u 子树的节点,calc() 结束以后用来清空信息,避免 memset 导致复杂度变成 O(n2)
  4. 记得根节点 u 的信息在一开始要先储存进去(或者后续贡献 ans[] 时单独考虑)。
代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 1e8+5;
int n, m;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
bool ans[maxn], has_dis[maxm]; // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
for (int p = hd; p <= tail; p++) {
for (int k = 1; k <= m; k++) { // 遍历询问
int q_dis = queries[k];
if (q_dis >= q[p]) {
ans[k] |= has_dis[q_dis - q[p]];
}
}
}
for (int p = hd; p <= tail; p++) {
has_dis[q[p]] = 1;
tmp.push_back(q[p]);
}
}
for (int d : tmp) has_dis[d] = 0; // 清空 u 的子树信息
tmp.clear();
}
void solve(int u) {
vis[u] = 1; dis[u] = 0;
has_dis[0] = 1; // 初始情况
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
for (int i = 1; i <= m; i++) {
int k; cin >> k;
queries[i] = k;
}
find_centroid(1, 0);
solve(rt);
for (int i = 1; i <= m; i++) {
cout << (ans[i] ? "AYE" : "NAY") << "\n";
}
}

例2 洛谷P4178 Tree

题意

给定一棵有 n 个点的树,边有权值。

求出树上两点距离小于等于 k 的点对数量。

其中,n4×104,w[0,103],k[0,2×104]

题解

小于等于 k 的话,用树状数组维护一下就可以了,剩下的和上一题几乎没区别。

代码
copy
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <bits/stdc++.h>
using namespace std;
struct BIT {
int n, tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= n) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
int query(int p) {
int ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr;
int n, k;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
ll ans = 0;
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
for (int p = hd; p <= tail; p++) {
int d = q[p];
// d + x <= k 说明 x <= k - d
if (d <= k) {
ans += tr.query(k - d);
ans++; // 到根节点
}
}
for (int p = hd; p <= tail; p++) {
if (q[p] <= k) {
tr.update(q[p], 1);
tmp.push_back(q[p]);
}
}
}
// printf("u = %d, ans = %lld\n",u,ans);
for (int d : tmp) tr.update(d, -1);
tmp.clear();
}
void solve(int u) {
vis[u] = 1; dis[u] = 0;
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
tr.n = 2e4;
cin >> n;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
cin >> k;
find_centroid(1, 0);
solve(rt);
cout << ans << "\n";
}