点分治
Contents
介绍
点分治用于处理 树上的路径问题。
点分治的主要思想是,对于一棵子树,子树内的所有路径只有两种情况:
- 不经过根节点
- 经过根节点
对于第一种,我们可以在处理其他子树的时候再讨论。
对于第二种,注意到一个经过根节点的路径,可以被拆分成从根节点出发的一条路径,合并上另外一条从根节点出发的路径。
所以点分治的核心思想就是对于每一个子树,都只考虑从根节点出发的路径,这些路径有 $O(n)$ 条。
但极端情况下,比如一条链,这样的复杂度可能来到 $O(n^2)$,所以在寻找一个子树的根时,应该将这个子树的重心作为根,这样递归的时候深度最多就是 $O(\log n)$。
总体上来说复杂度就是 $O(n \log n)$。
模版
int n, m;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
bool ans[maxn], has_dis[maxm]; // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
// 计算 u 为根,所有以 u 出发的路径带来的贡献
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
// 将子树 v 的贡献加进 ans
for (int p = hd; p <= tail; p++) { // 遍历子树 v 的节点
for (int k = 1; k <= m; k++) { // 遍历每一个询问
int q_dis = queries[k];
if (q_dis >= q[p]) {
ans[k] |= has_dis[q_dis - q[p]];
}
}
}
// 考虑完子树 v 以后,将子树 v 的信息储存进去
for (int p = hd; p <= tail; p++) {
has_dis[q[p]] = 1;
tmp.push_back(q[p]);
}
}
for (int d : tmp) has_dis[d] = 0; // 清空 u 的子树信息
tmp.clear();
}
// 分治 u
void solve(int u) {
vis[u] = 1; dis[u] = 0;
has_dis[0] = 1; // 初始情况
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
for (int i = 1; i <= m; i++) {
int k; cin >> k;
queries[i] = k;
}
find_centroid(1, 0);
solve(rt);
for (int i = 1; i <= m; i++) {
cout << (ans[i] ? "AYE" : "NAY") << "\n";
}
}
在模版中,主要修改的部分就是 calc(u)
函数,这个函数代表着题目要求计算什么样的路径。
例题
例1 洛谷P3806 点分治1
题意
给定一棵有 $n$ 个点的树,边有权值。
有 $m$ 次询问,每次询问树上距离为 $k$ 的点对是否存在。
其中,$n \leq 10^4, m \leq 100, 1 \leq k \leq 10^7, 1 \leq w \leq 10^4$。
题解
注意到 $k \leq 10^7$,所以可以开一个数组来储存长度为 $x$ 的路径(从根出发)是否存在。
然后就是点分治的模版了,有几个点可能需要注意:
- 点分治的
calc(u)
过程里,枚举了每个子树 $v$,枚举一个 $v$ 得到子树信息以后,先将信息贡献给ans[]
,然后才储存进当前子树内。这是为了防止出现非法的情况,比如一个路径,两个端点都在同一个子树 $v$ 内,这也类似于树形 DP 的思想。 - 使用了一个队列
q
来储存子树 $v$ 的信息。 - 使用了一个
vector<> tmp
来储存整个 $u$ 子树的节点,calc()
结束以后用来清空信息,避免 memset 导致复杂度变成 $O(n^2)$。 - 记得根节点 $u$ 的信息在一开始要先储存进去(或者后续贡献
ans[]
时单独考虑)。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 1e8+5;
int n, m;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
bool ans[maxn], has_dis[maxm]; // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
for (int p = hd; p <= tail; p++) {
for (int k = 1; k <= m; k++) { // 遍历询问
int q_dis = queries[k];
if (q_dis >= q[p]) {
ans[k] |= has_dis[q_dis - q[p]];
}
}
}
for (int p = hd; p <= tail; p++) {
has_dis[q[p]] = 1;
tmp.push_back(q[p]);
}
}
for (int d : tmp) has_dis[d] = 0; // 清空 u 的子树信息
tmp.clear();
}
void solve(int u) {
vis[u] = 1; dis[u] = 0;
has_dis[0] = 1; // 初始情况
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
cin >> n >> m;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
for (int i = 1; i <= m; i++) {
int k; cin >> k;
queries[i] = k;
}
find_centroid(1, 0);
solve(rt);
for (int i = 1; i <= m; i++) {
cout << (ans[i] ? "AYE" : "NAY") << "\n";
}
}
例2 洛谷P4178 Tree
题意
给定一棵有 $n$ 个点的树,边有权值。
求出树上两点距离小于等于 $k$ 的点对数量。
其中,$n \leq 4 \times 10^4, w \in [0, 10^3], k \in [0, 2 \times 10^4]$。
题解
小于等于 $k$ 的话,用树状数组维护一下就可以了,剩下的和上一题几乎没区别。
代码
#include <bits/stdc++.h>
using namespace std;
struct BIT {
int n, tr[maxn];
inline int lowbit(int x) { return x & -x; }
void update(int p, int val) {
while (p <= n) {
tr[p] += val;
p += lowbit(p);
}
}
// return sum[1...p]
int query(int p) {
int ans = 0;
while (p > 0) {
ans += tr[p];
p -= lowbit(p);
}
return ans;
}
} tr;
int n, k;
struct Edge {
int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];
ll ans = 0;
void addEdge(int u, int v, int w) {
Edge e = {v, head[u], w};
head[u] = ecnt;
edges[ecnt++] = e;
}
int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
sz[u] = 1;
int mx = 0;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
find_centroid(v, u);
sz[u] += sz[v];
mx = max(mx, sz[v]);
}
mx = max(mx, cursz - sz[u]);
if (mx <= cursz / 2) rt = u;
}
void get_cursz(int u, int p) {
sz[u] = 1;
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v] || v == p) continue;
get_cursz(v, u);
sz[u] += sz[v];
}
}
void getdis(int u, int p) {
q[++tail] = dis[u];
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (v == p || vis[v]) continue;
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
}
}
void calc(int u) {
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
hd = 1, tail = 0; // 清空 v 的子树信息
dis[v] = dis[u] + edges[e].w;
getdis(v, u);
for (int p = hd; p <= tail; p++) {
int d = q[p];
// d + x <= k 说明 x <= k - d
if (d <= k) {
ans += tr.query(k - d);
ans++; // 到根节点
}
}
for (int p = hd; p <= tail; p++) {
if (q[p] <= k) {
tr.update(q[p], 1);
tmp.push_back(q[p]);
}
}
}
// printf("u = %d, ans = %lld\n",u,ans);
for (int d : tmp) tr.update(d, -1);
tmp.clear();
}
void solve(int u) {
vis[u] = 1; dis[u] = 0;
calc(u);
// 处理答案
for (int e = head[u]; e; e = edges[e].nxt) {
int v = edges[e].to;
if (vis[v]) continue;
get_cursz(v, 0);
cursz = sz[v];
find_centroid(v, u);
solve(rt); // 子树
}
}
int main() {
tr.n = 2e4;
cin >> n;
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
addEdge(u, v, w); addEdge(v, u, w);
}
cin >> k;
find_centroid(1, 0);
solve(rt);
cout << ans << "\n";
}