介绍

点分治用于处理 树上的路径问题

点分治的主要思想是,对于一棵子树,子树内的所有路径只有两种情况:

  1. 不经过根节点
  2. 经过根节点

对于第一种,我们可以在处理其他子树的时候再讨论。

对于第二种,注意到一个经过根节点的路径,可以被拆分成从根节点出发的一条路径,合并上另外一条从根节点出发的路径。

所以点分治的核心思想就是对于每一个子树,都只考虑从根节点出发的路径,这些路径有 $O(n)$ 条。

但极端情况下,比如一条链,这样的复杂度可能来到 $O(n^2)$,所以在寻找一个子树的根时,应该将这个子树的重心作为根,这样递归的时候深度最多就是 $O(\log n)$。

总体上来说复杂度就是 $O(n \log n)$。

模版
int n, m;
struct Edge {
    int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];  
bool ans[maxn], has_dis[maxm];  // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
    Edge e = {v, head[u], w};
    head[u] = ecnt;
    edges[ecnt++] = e;
}

int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
    sz[u] = 1;
    int mx = 0;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        find_centroid(v, u);
        sz[u] += sz[v];
        mx = max(mx, sz[v]);
    }
    mx = max(mx, cursz - sz[u]);
    if (mx <= cursz / 2) rt = u;
}

void get_cursz(int u, int p) {
    sz[u] = 1;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v] || v == p) continue;
        get_cursz(v, u);
        sz[u] += sz[v];
    }
}

void getdis(int u, int p) {
    q[++tail] = dis[u];
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
    }
}

// 计算 u 为根,所有以 u 出发的路径带来的贡献
void calc(int u) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        hd = 1, tail = 0;  // 清空 v 的子树信息
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
        // 将子树 v 的贡献加进 ans
        for (int p = hd; p <= tail; p++) {   // 遍历子树 v 的节点
            for (int k = 1; k <= m; k++) {  // 遍历每一个询问
                int q_dis = queries[k];
                if (q_dis >= q[p]) {
                    ans[k] |= has_dis[q_dis - q[p]];
                }
            }
        }
        // 考虑完子树 v 以后,将子树 v 的信息储存进去
        for (int p = hd; p <= tail; p++) {
            has_dis[q[p]] = 1;
            tmp.push_back(q[p]);
        }
    }
    for (int d : tmp) has_dis[d] = 0;  // 清空 u 的子树信息
    tmp.clear();
}

// 分治 u
void solve(int u) {
    vis[u] = 1; dis[u] = 0; 
    has_dis[0] = 1;  // 初始情况
    calc(u);
    // 处理答案
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        get_cursz(v, 0);
        cursz = sz[v];
        find_centroid(v, u);
        solve(rt);  // 子树
    }
}


int main() {
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        addEdge(u, v, w); addEdge(v, u, w);
    }
    for (int i = 1; i <= m; i++) {
        int k; cin >> k;
        queries[i] = k;
    }
    
    find_centroid(1, 0);
    solve(rt);

    for (int i = 1; i <= m; i++) {
        cout << (ans[i] ? "AYE" : "NAY") << "\n";
    }
}

在模版中,主要修改的部分就是 calc(u) 函数,这个函数代表着题目要求计算什么样的路径。

例题

例1 洛谷P3806 点分治1

题意

给定一棵有 $n$ 个点的树,边有权值。

有 $m$ 次询问,每次询问树上距离为 $k$ 的点对是否存在。

其中,$n \leq 10^4, m \leq 100, 1 \leq k \leq 10^7, 1 \leq w \leq 10^4$。

题解

注意到 $k \leq 10^7$,所以可以开一个数组来储存长度为 $x$ 的路径(从根出发)是否存在。

然后就是点分治的模版了,有几个点可能需要注意:

  1. 点分治的 calc(u) 过程里,枚举了每个子树 $v$,枚举一个 $v$ 得到子树信息以后,先将信息贡献给 ans[],然后才储存进当前子树内。这是为了防止出现非法的情况,比如一个路径,两个端点都在同一个子树 $v$ 内,这也类似于树形 DP 的思想。
  2. 使用了一个队列 q 来储存子树 $v$ 的信息。
  3. 使用了一个 vector<> tmp 来储存整个 $u$ 子树的节点,calc() 结束以后用来清空信息,避免 memset 导致复杂度变成 $O(n^2)$。
  4. 记得根节点 $u$ 的信息在一开始要先储存进去(或者后续贡献 ans[] 时单独考虑)。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 1e8+5;

int n, m;
struct Edge {
    int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];  
bool ans[maxn], has_dis[maxm];  // has_dis[i]: 当前子树中存在到根距离为i的节点
void addEdge(int u, int v, int w) {
    Edge e = {v, head[u], w};
    head[u] = ecnt;
    edges[ecnt++] = e;
}

int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
    sz[u] = 1;
    int mx = 0;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        find_centroid(v, u);
        sz[u] += sz[v];
        mx = max(mx, sz[v]);
    }
    mx = max(mx, cursz - sz[u]);
    if (mx <= cursz / 2) rt = u;
}

void get_cursz(int u, int p) {
    sz[u] = 1;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v] || v == p) continue;
        get_cursz(v, u);
        sz[u] += sz[v];
    }
}

void getdis(int u, int p) {
    q[++tail] = dis[u];
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
    }
}

void calc(int u) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        hd = 1, tail = 0;  // 清空 v 的子树信息
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
        for (int p = hd; p <= tail; p++) {
            for (int k = 1; k <= m; k++) {  // 遍历询问
                int q_dis = queries[k];
                if (q_dis >= q[p]) {
                    ans[k] |= has_dis[q_dis - q[p]];
                }
            }
        }

        for (int p = hd; p <= tail; p++) {
            has_dis[q[p]] = 1;
            tmp.push_back(q[p]);
        }
    }
    for (int d : tmp) has_dis[d] = 0;  // 清空 u 的子树信息
    tmp.clear();
}

void solve(int u) {
    vis[u] = 1; dis[u] = 0; 
    has_dis[0] = 1;  // 初始情况
    calc(u);
    // 处理答案
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        get_cursz(v, 0);
        cursz = sz[v];
        find_centroid(v, u);
        solve(rt);  // 子树
    }
}


int main() {
    cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        addEdge(u, v, w); addEdge(v, u, w);
    }
    for (int i = 1; i <= m; i++) {
        int k; cin >> k;
        queries[i] = k;
    }
    
    find_centroid(1, 0);
    solve(rt);

    for (int i = 1; i <= m; i++) {
        cout << (ans[i] ? "AYE" : "NAY") << "\n";
    }
}

例2 洛谷P4178 Tree

题意

给定一棵有 $n$ 个点的树,边有权值。

求出树上两点距离小于等于 $k$ 的点对数量。

其中,$n \leq 4 \times 10^4, w \in [0, 10^3], k \in [0, 2 \times 10^4]$。

题解

小于等于 $k$ 的话,用树状数组维护一下就可以了,剩下的和上一题几乎没区别。

代码
#include <bits/stdc++.h>
using namespace std;
struct BIT {
    int n, tr[maxn];
    inline int lowbit(int x) { return x & -x; }
    void update(int p, int val) {
        while (p <= n) {
            tr[p] += val;
            p += lowbit(p);
        }
    }
    // return sum[1...p]
    int query(int p) {
        int ans = 0;
        while (p > 0) {
            ans += tr[p];
            p -= lowbit(p);
        }
        return ans;
    }
} tr;

int n, k;
struct Edge {
    int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, queries[maxn];  
ll ans = 0;

void addEdge(int u, int v, int w) {
    Edge e = {v, head[u], w};
    head[u] = ecnt;
    edges[ecnt++] = e;
}

int sz[maxn], dis[maxn], q[maxn], hd = 1, tail = 0;
bool vis[maxn];
vector<int> tmp;
int cursz, rt;
void find_centroid(int u, int p) {
    sz[u] = 1;
    int mx = 0;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        find_centroid(v, u);
        sz[u] += sz[v];
        mx = max(mx, sz[v]);
    }
    mx = max(mx, cursz - sz[u]);
    if (mx <= cursz / 2) rt = u;
}

void get_cursz(int u, int p) {
    sz[u] = 1;
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v] || v == p) continue;
        get_cursz(v, u);
        sz[u] += sz[v];
    }
}

void getdis(int u, int p) {
    q[++tail] = dis[u];
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (v == p || vis[v]) continue;
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
    }
}

void calc(int u) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        hd = 1, tail = 0;  // 清空 v 的子树信息
        dis[v] = dis[u] + edges[e].w;
        getdis(v, u);
        for (int p = hd; p <= tail; p++) {
            int d = q[p];
            // d + x <= k 说明 x <= k - d

            if (d <= k) {
                ans += tr.query(k - d);
                ans++;  // 到根节点
            }
        }

        for (int p = hd; p <= tail; p++) {
            if (q[p] <= k) {
                tr.update(q[p], 1);
                tmp.push_back(q[p]);
            }
        }
    }
    // printf("u = %d, ans = %lld\n",u,ans);
    for (int d : tmp) tr.update(d, -1);
    tmp.clear();
}

void solve(int u) {
    vis[u] = 1; dis[u] = 0; 
    calc(u);
    // 处理答案
    for (int e = head[u]; e; e = edges[e].nxt) {
        int v = edges[e].to;
        if (vis[v]) continue;
        get_cursz(v, 0);
        cursz = sz[v];
        find_centroid(v, u);
        solve(rt);  // 子树
    }
}


int main() {
    tr.n = 2e4;
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        addEdge(u, v, w); addEdge(v, u, w);
    }
    cin >> k;
    
    find_centroid(1, 0);
    solve(rt);
    cout << ans << "\n";
}