介绍

01-Trie 指将整数拆成二进制,然后将二进制以字符串的形式储存在 Trie 中。

img

Trie可以解决以下问题:

  1. 给定一些数 $a_1,a_2,…,a_n$,给定 $x$,求 $a_i$,使得 $a_i \text{ xor } x$ 最大。

  2. 维护异或和:给定一些数,支持全体加一,插入数字,删除数字,求全体异或和等操作。

Trie可以将数字从低位到高位储存,也可以反过来,根据具体题目而定。

为了方便,我们在写 01Trie 的时候都会把每个数补成同样的位数。

• 注意 id = 1 的是 Root,所以 id 要从 $1$ 开始。

维护最大XOR

指第一个例子,这里我们的 01Trie 将会 从高位到低位 储存。

在给定一个查询 $x$ 时,我们从 $x$ 的高位开始看,设当前到了第 $i$ 位,那么我们看第 $i$ 位的bit $a_i$,然后判断一下第 $i$ 位的值为 a[i] ^ 1 的数字是否存在即可,如果存在,往那个方向走,否则往另外一个方向走。

板子
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;

int id = 1;  // 注意,从 1 开始
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt++;
    }
}
// 求 trie中的一个数 a_i, 使得 a_i ^ x 最大
int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[c].child[k]) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
        }
    }
    return res;
}

维护XOR和

指第二个例子,这里我们的 01Trie 从低位到高位储存。

要维护异或和,我们只需要知道每一位上 $0$ 和 $1$ 个数的奇偶性即可。

对于每一个节点,我们记录三个量:

  1. 两个子节点的编号
  2. w:指当前这个节点,到它的parent这条边 $(p,u)$ 上,被经过的次数。每有一个数字在插入过程中经过这个边,这个 w 就会加一。
  3. val:指以当前节点为根,它子树内包含的所有数字的 XOR和

所以维护信息的时候就有:

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

解释: 注意是从低位到高位储存的,所以子树内储存的只是数字的二进制的一部分。

所以要把 val 进行左移。

也就是说,实际上只有 根节点val 是真正有意义的,因为只有根节点才记录的是完整的数字的 XOR和。


插入和删除就不提了,看代码即可。注意插入删除合并起来,可以达成单点修改的效果。

最后说一下全局加一。

全局加一指的是将整个01Trie维护的所有数字都加一。

在二进制下的加一,相当于从低位到高位找第一个出现的 $0$,然后将其变成 $1$。

最后,比这一位低的所有位都将会从 $1$ 变成 $0$。

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}
板子
const int M = 21;  // 最大深度
int id = 0;
struct Node {
    int w;  // 到 parent 这条边上的个数
    int val;  // subtree 内的XOR和
    int child[2];
} trie[maxn*(M+1)];

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w++;  // w += 1
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

void del(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w--;  // w -= 1
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}

可持久化01-Trie

和主席树相似,可持久化01-Trie主要是为了让我们获得任何一个区间 $[L,R]$ 内组成的 01-Trie。

使用 Node 的 cnt 来决定 01-Trie 的内容。

板子
const int maxn = 6e5+5;
const int M = 28;
struct Node {
    int cnt;
    int child[2];
} trie[maxn * (M+3)];
int n, q, root[maxn], a[maxn], id;

void insert(int pre, int cur, int x) {
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        trie[cur].cnt = trie[pre].cnt + 1;  // cnt 加一
        if (!trie[cur].child[c]) trie[cur].child[c] = ++id;  // 新建节点
        trie[cur].child[c^1] = trie[pre].child[c^1];  // 复制另外一个子节点
        cur = trie[cur].child[c];
        pre = trie[pre].child[c];
    }
    trie[cur].cnt = trie[pre].cnt + 1;
}

// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
    int res = 0;
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        c ^= 1;
        if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) {  // cnt > 0
            cur = trie[cur].child[c];
            pre = trie[pre].child[c];
            res |= (1<<i);
        } else {
            cur = trie[cur].child[c^1];
            pre = trie[pre].child[c^1];
        }
    }
    return res;
}
int main() {
    for (int i = 1; i <= n; i++) {
        cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
        if (!root[i]) root[i] = ++id;
        insert(root[i-1], root[i], sum[i]);
    }
    // 查询 [L,R] -> (l-1,r)
    int res = query(root[max(0,l-1)], root[r], x);
}

例题

例1 洛谷P4551 最长异或路径

题意

给定一个 $n$ 个节点的树,边上有权值。

寻找树上的两个节点,使得路径上的边权 XOR和 最大,输出这个最大值。

其中,$n \leq 10^5$。

题解

常见套路:设 $f_x$ 为从 $1$(根节点)到 $x$ 的路径上的 XOR和。

那么 $(u,v)$ 路径上的 XOR和 就等于 $f_u \text{ xor } f_v$。因为 $1$ 到 $LCA(u,v)$ 的部分被抵消掉了。

所以就相当于建立一个 01Trie,储存所有的 $f_u$,然后对于每个 $u$,都询问一下最大的XOR即可。

代码
#include <bits/stdc++.h>
using namespace std;
const ll mod = 1e9+7;
const int maxn = 1e5+5;
const int maxm = 30;

struct Edge {
    int to, nxt, w;
} edges[maxn<<1];
int head[maxn], ecnt = 1, n;
void addEdge(int u, int v, int w) {
    Edge e = {v, head[u], w};
    edges[ecnt] = e;
    head[u] = ecnt++;
}
int dp[maxn];
void dfs(int u, int p) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int to = edges[e].to;
        if (to == p) continue;
        dp[to] = dp[u] ^ edges[e].w;
        dfs(to, u);
    }
}

int id = 1;
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt++;
    }
}

int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[c].child[k]) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
        }
    }
    return res;
}


int main() {
    fastio;
    cin >> n;
    for (int i = 1; i <= n-1; i++)     {
        int u,v,w; cin >> u >> v >> w;
        addEdge(u,v,w); addEdge(v,u,w);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i++) insert(dp[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        ans = max(ans, query(dp[i]));
    }
    cout << ans << endl;
}

例2 CF817E Choosing The Commander

题意

给定 $q$ 个询问,共有3种:

$1 ~ p$:将 $p$ 加入集合。

$2 ~ p$:将 $p$ 从集合中删除。

$3 ~ p ~ l$:询问集合中有多少个元素在 XOR $p$ 之后小于 $l$。

其中,$q \leq 10^5$。

题解

在 01-Trie 每个节点处加上一个 count 就可以了。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2e5+5;
const int maxm = 27;
 
int Q, id = 1;
struct Node {
    int child[2];
    int cnt = 0;
} trie[maxn<<3];
void insert(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt++;
    }
}
void del(int x) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt--;
    }
}
 
int query(int p, int l) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        if (l & (1<<j)) {  // 如果这一位有,说明要往这一位为 1 的方向走
            int k = ((p & (1<<j)) > 0);
            res += trie[trie[c].child[k]].cnt;
            if (trie[trie[c].child[k^1]].cnt) c = trie[c].child[k^1];
            else break;
        } else {  // 如果这一位没有,就往 0 的方向走
            int k = ((p & (1<<j)) > 0);
            if (trie[trie[c].child[k]].cnt) c = trie[c].child[k];
            else break;
        }
    }
    return res;
}
 
int main() {
    fastio;
    cin >> Q;
    while (Q--) {
        int op, p, l;
        cin >> op >> p;
        if (op == 1) {
            insert(p);
        } else if (op == 2) {
            del(p);
        } else {
            cin >> l;
            int res = query(p,l);
            cout << res << "\n";
        }
    }
}

例3 洛谷P6018 [Ynoi2010] Fusion tree

题意

给定一棵 $n$ 个节点的树,每个节点 $i$ 上有权值,初始权值为 $a_i$。

给出 $m$ 个询问,询问的类型有3种:

$1 ~ x$:将所有与节点 $x$ 距离为 $1$ 的节点的权值加一。

$2 ~ x ~ v$:将节点 $x$ 的权值减去 $v$。

$3 ~ x$:询问所有与节点 $x$ 距离为 $1$ 的节点权值的XOR和。

其中,$n,m \leq 5 \times 10^5$,初始权值 $\leq 10^5$。

题解

这里是 01-Trie 用来维护XOR和。

我们可以对于每一个节点,都把它的孩子(距离为 $1$)的XOR和 维护在一个 01Trie里面。

因为 01-Trie 是动态开点的,所以总共复杂度也就 $n\log(10^5)$。

然后因为每个节点都只有一个 parent,就单独维护即可。

然后对于操作 $1 ~ x$,我们需要做的几件事:

  1. 单独更新 $x$ 的parent。
  2. 给 $x$ 打上一个懒标记,代表它要给它的孩子加上 $lazy$ 的值。
  3. 在 $x$ 维护的 01-Trie 上进行全体加一操作。

然后在每次操作之前,都先检查一下 $parent$ 的懒标记,将其下放。

当然这里的下放不太一样,一个节点可能有非常多个孩子,所以我们选择给每一个节点 $x$ 都定义一个值 add[x],代表它已经从parent那里得到了多少下放的懒标记。

然后就给 $x$ 的权值加上 (lazy[p] - add[x])

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5e5+5;
const int M = 21;  // 最大深度

int id = 0;
struct Node {
    int w;  // 到 parent 这条边上的个数
    int val;  // subtree 内的XOR和
    int child[2];
} trie[maxn*(M+1)];
int root[maxn], lazy[maxn], add[maxn];

void push_up(int cur) {
    trie[cur].val = trie[cur].w = 0;  // 先清空
    int c0 = trie[cur].child[0], c1 = trie[cur].child[1];

    // 更新 w
    if (c0) trie[cur].w += trie[c0].w;
    if (c1) trie[cur].w += trie[c1].w;

    // 更新 val

    if (c0) {
        trie[cur].val ^= (trie[c0].val << 1);
    }
    if (c1) {
        trie[cur].val ^= ((trie[c1].val << 1) | (trie[c1].w & 1));  // 如果 c1 的 w 为奇数,就提供了 1 的贡献
    }
}

// cur: index, x: 当前的数字, dep: 深度
void insert(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w++;
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

void del(int& cur, int x, int dep) {
    if (!cur) cur = ++id;
    if (dep >= M) {
        trie[cur].w--;
        return;
    }
    int c = x & 1;
    insert(trie[cur].child[c], (x>>1), dep+1);
    push_up(cur);
}

// add 1 to all numbers maintained by cur
void addall(int cur) {
    swap(trie[cur].child[0], trie[cur].child[1]);
    if (trie[cur].child[0]) {
        addall(trie[cur].child[0]);
    }
    push_up(cur);
}

int n, m, a[maxn], par[maxn];
struct Edge {
    int to, nxt;
} edges[maxn<<1];
int head[maxn], ecnt = 1;
void addEdge(int u, int v) {
    Edge e = {v, head[u]};
    head[u] = ecnt;
    edges[ecnt++] = e;
}

void dfs(int u) {
    for (int e = head[u]; e; e = edges[e].nxt) {
        int to = edges[e].to;
        if (to == par[u]) continue;
        par[to] = u;
        insert(root[u], a[to], 0);
        dfs(to);
    }
}

// modify u to have the value val
void modify(int u, int val) {
    // a[u] is previous value
    int p = par[u];
    if (p) {
        del(root[p], a[u], 0);
        insert(root[p], val, 0);
    }
    a[u] = val;
}

int main() {
    fastio;
    cin >> n >> m;
    for (int i = 1; i <= n-1; i++) {
        int u,v; cin >> u >> v;
        addEdge(u,v); addEdge(v,u);
    }
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
    }
    dfs(1);

    while (m--) {
        int op, x; cin >> op >> x;
        if (op == 1) {
            addall(root[x]);
            int p = par[x];
            if (p) {
                a[p] = a[p] + lazy[par[p]] - add[p];
                add[p] = lazy[par[p]];
                modify(par[x], a[p] + 1);
            }
            lazy[x]++;
        }
        if (op == 2) {
            int v; cin >> v;
            a[x] = a[x] + lazy[par[x]] - add[x];
            add[x] = lazy[par[x]];
            modify(x, a[x] - v);
        }
        if (op == 3) {
            int res = trie[root[x]].val;
            int p = par[x];
            if (p) {
                a[p] = a[p] + lazy[par[p]] - add[p];
                add[p] = lazy[par[p]];
                res ^= a[p];
            }
            cout << res << "\n";
        }
    }
}

例4 洛谷P4735 最大异或和

题意

给定一个非负整数序列 $a_1,a_2,…,a_n$。

现给定 $m$ 个操作,有以下两种操作类型:

$A ~ x$:在序列末尾添加一个数 $x$。

$Q ~ L ~ R ~ x$:输出一个位置 $p$,满足 $p \in [L,R]$,使得 $a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x$ 最大。

其中,$n,m \leq 3 \times 10^5, a_i \in [0,10^7]$。

题解

可持久化 01-Trie 的模版题。

首先,先考虑前缀XOR数组 $s_i = a_1 \bigoplus a_2 \bigoplus … \bigoplus a_i$。

那么

$$a_p \bigoplus a_{p+1} \bigoplus … \bigoplus a_n \bigoplus x = (s_n \bigoplus s_{p-1}) \bigoplus x$$

所以现在问题就变成:

每次询问 $Q ~ L ~ R ~ x$,输出一个位置 $p$,满足 $p \in [L-1,R-1]$,使得 $(s_n \bigoplus s_{p-1}) \bigoplus x$ 最大。

那就很简单了,令 $y = s_n \bigoplus x$,剩下的就是在一个区间 $[L-1,R-1]$ 内找出一个元素 $s_p$ 使得 $s_p \bigoplus y$ 最大。

对于每一个区间都可以用 01-Trie 解决的问题,就是可持久化 01-Trie了。


最后需要注意一下这个样例:

4 1
2 4 8 16
Q 1 4 1

答案应为 $31$,但我输出 $29$。

这是因为我们要单独处理一下 $s_0$ 的情况。换而言之我们要将 $s_0 = 0$ 插入到 root[1] 当中。

所以 root[1] 需要包含 $2$ 个元素:$s_0,s_1$。

对于一个版本插入 $2$ 个元素,只要在第二次插入时使用 insert(root[1], root[1], 0) 即可。

最后再注意一下一些边界条件如 $[L-1,R-1]=[0,0]$ 即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 6e5+5;
const int M = 28;

struct Node {
    int cnt;
    int child[2];
} trie[maxn * (M+1)];
int sum[maxn];  // sum[n] 代表 a[1] ^ a[2] ... ^ a[n]
int n, q, root[maxn], a[maxn], id;

void insert(int pre, int cur, int x) {
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        trie[cur].cnt = trie[pre].cnt + 1;  // cnt 加一
        if (!trie[cur].child[c]) trie[cur].child[c] = ++id;  // 新建节点
        trie[cur].child[c^1] = trie[pre].child[c^1];  // 复制另外一个子节点
        cur = trie[cur].child[c];
        pre = trie[pre].child[c];
    }
    trie[cur].cnt = trie[pre].cnt + 1;
}

// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
    int res = 0;
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        c ^= 1;
        if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) {  // cnt > 0
            cur = trie[cur].child[c];
            pre = trie[pre].child[c];
            res |= (1<<i);
        } else {
            cur = trie[cur].child[c^1];
            pre = trie[pre].child[c^1];
        }
    }
    return res;
}


int main() {
    fastio;
    cin >> n >> q;

    for (int i = 1; i <= n; i++) {
        cin >> a[i]; sum[i] = sum[i-1] ^ a[i];
        if (!root[i]) root[i] = ++id;
        insert(root[i-1], root[i], sum[i]);
        if (i == 1) insert(root[1], root[1], 0);
    }
    while (q--) {
        char op; cin >> op;
        if (op == 'A') {
            int x; cin >> x;
            a[++n] = x;
            sum[n] = sum[n-1] ^ a[n];
            if (!root[n]) root[n] = ++id;
            insert(root[n-1], root[n], sum[n]);
        } else {
            int l,r,x; cin >> l >> r >> x;
            l--; r--;
            x ^= (sum[n]);
            if (l == 0 && r == 0) {
                cout << x << "\n"; continue;
            }
            int res = query(root[max(0,l-1)], root[r], x);
            cout << res << "\n";
        }
    }
}

例5 CF241B Friends

题意

给定一个长度为 $n$ 的非负整数数组 $a_1,a_2,…,a_n$。

现在选择 $m$ 个pair $(i,j)$,使得这 $m$ 个pair 对应的 $a_i \text{ xor } a_j$ 值的 sum 最大。

其中,$n \leq 5 \times 10^4, a_i \leq 10^9$,答案对 $10^9+7$ 取模。

• $(i,j)$ 和 $(j,i)$ 算作一个 pair。

题解

神仙题。

本题要求的是前 $m$ 大的XOR pair的和。

我们先列出本题的过程:

  1. 求出第 $m$ 大。
  2. 求出 XOR 小于等于 $x$ 的所有pair的和。
  3. 对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。

第一步:求第 $m$ 大

• 虽然 $(i,j)$ 和 $(j,i)$ 算作一个 pair,但我们只要把所有东西都乘以 $2$ 就行了。所以我们求第 $2m$ 大。

最简单粗暴的方法当然是二分最终的结果 $x$,然后对于每一个 $a_i$,判断一下在 Trie 中有多少个数字和它 XOR 起来 $\geq x$,把这些 count 累加起来然后查看 count >= 2m 与否,来调整 $x$ 的值。

这样的时间复杂度是 $O(n\log^2 10^9)$,本题可过,但是另外一个版本不可过。

于是我们思考一下能不能直接在 Trie 上模拟这个二分的过程来构建最终的结果 $x$。

发现是可以的,我们从高位到低位,枚举每一位取 $1$ 还是 取 $0$。

然后我们对于每一个 $a_i$ 都判断一下 count 的总和,如果 count >= 2m,则说明这一位可以取 $1$。否则的话需要取 $0$,同时问题变成寻找第 $2m - count$ 大(有点类似于主席树寻找第 $k$ 大的思路)。

• 这个过程是在所有 $a_i$ 上进行的,所以我们维护 $n$ 个指针 ptr[] 代表 $a_i$ 目前在 Trie 的什么位置。

复杂度为 $O(n\log 10^9)$。


第二步:求出 XOR 小于等于 $x$ 的所有pair的和

如果这个问题不是求出和,而是求出个数的话,就是例 $2$ 的问题了。

可惜不是,那怎么处理?

一样,我们从每个 $a_i$ 开始进行 check,对于每一个 Trie 上的节点,我们可以维护一个 cnt[31],其中 cnt[j] 代表当前Trie节点的子树里有多少个数字满足第 $j$ 位为 $1$。

这样当然可以算,但是空间复杂度 $O(n\log^2 10^9)$,本题或许可过,但另外一个版本仍然不可过。

我们这里给出一个非常优秀的结论:

如果 $a_i$ 是 sorted 的,那么 Trie 上任意一个节点的子树 对应 $a_i$ 上的一段连续区间。

这个也不难理解,因为 Trie 的一个子树对应的是拥有一个 “特定前缀” (如 “1101”)的所有数字的集合。

有了这个结论以后,我们在枚举 $a_i$ 时,假设我们目前到了第 $j$ 位,就分两种情况:

  1. $x$ 的第 $j$ 位为 $1$:那么我们就往 $1$ 的方向走。
  2. $x$ 的第 $j$ 位为 $0$:我们把 $a_i$ XOR 上 $1$ 的那个子树,求出sum,给答案加上,然后往 $0$ 的方向走。

注意到 $a_i$ XOR 上一个子树的 sum,可以转化成一个区间上的问题,那么就有第三步:


第三步:对于一个区间 $[L,R]$,在 $O(1)$ 求出 $\sum\limits_{i=L}^R y \text{ xor } a_i$。

由于 $a_i$ 总共有 $n$ 个,所以不能直接求前缀和。

但我们可以把区间内的每个数都 拆成二进制

然后对于拆开的二进制,每一位分别 XOR 上 $0$ 和 $1$,然后都进行一个前缀和。

这样给定一个数字 $y$,就可以把 $y$ 拆成二进制,然后对于每一位 $j$,把 $[L,R]$ 内对应的和用 $O(1)$ 时间内求出来即可。

• 所以我们要做的事就是先 sort 一下整个 $a_i$ 数组,然后在构建 Trie 的时候记录一下每个子树对应哪个区间即可。


最后要注意,由于第 $2m$ 大有可能 等于 第 $2m+1, 2m+2 …$ 大,所以要处理掉额外加上的部分。

答案除以 $2$ 即可。

代码
#include <bits/stdc++.h>
using namespace std;

const ll mod = 1e9+7;
const int maxn = 5e4+5;
const int M = 31;
struct Node {
    int cnt = 0;
    int child[2];
    int L = 1e9, R = -1;  // left and right segments
} trie[maxn * (M+1)];
int id = 1;
void insert(int idx, ll x) {
    int c = 1;
    for (int i = M; i >= 0; i--) {
        int k = ((x & (1LL<<i)) ? 1 : 0);
        if (!trie[c].child[k]) trie[c].child[k] = ++id;
        c = trie[c].child[k];
        trie[c].cnt++;
        trie[c].L = min(trie[c].L, idx);  // 维护子树对应的区间
        trie[c].R = max(trie[c].R, idx);
    }
}

ll a[maxn];
ll n,m;

// 寻找第 k 大 (本题是找到第 2m 大)
int ptr[maxn];  // 每一个 ai 对应的 Trie 上的 ptr
ll find_kmax(ll t) {
    ll res = 0;
    fill(ptr+1, ptr+n+1, 1);  // 开始都在 1
    for (int j = M; j >= 0; j--) {
        ll cnt = 0;
        // 找这一位 >= 1 的有多少个
        for (int i = 1; i <= n; i++) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            k ^= 1;
            cnt += trie[trie[ptr[i]].child[k]].cnt;
        }
        int nxt;
        if (cnt >= t) {
            res |= (1LL<<j);
            nxt = 1;
        } else {
            t -= cnt;
            nxt = 0;
        }

        for (int i = 1; i <= n; i++) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            k ^= nxt;  // 由这一步决定的哪一个来判断每一个 ptr 的更新位置
            ptr[i] = trie[ptr[i]].child[k];
        }
    }
    return res;
}


int sum[maxn][M+1][2];
// 计算 a[L,R] ^ x 之和
ll sum_range(ll l, ll r, ll x) {
    ll ans = 0;
    for (int j = M; j >= 0; j--) {
        int k = ((x & (1LL<<j)) ? 1 : 0);
        ll s = sum[r][j][k] - sum[l-1][j][k];
        ans += s * (1LL << j);
        ans %= mod;
    }
    return ans;
}

ll qpow(ll a, ll b) {
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// 获得 ai ^ aj >= x 的所有数之和
ll get_sum(ll x) {
    // 先把每个数拆成二进制,做前缀和
    for (int j = M; j >= 0; j--) {
        for (int i = 1; i <= n; i++) {
            for (int k = 0; k <= 1; k++) {
                int c = ((a[i] & (1LL<<j)) ? 1 : 0);
                c ^= k;
                sum[i][j][k] = sum[i-1][j][k] + c;
            }
        }
    }

    ll ans = 2 * x % mod;
    ll cnt = 0;  // 使用的对数
    for (int i = 1; i <= n; i++) {
        int c = 1;
        for (int j = M; j >= 0; j--) {
            int k = ((a[i] & (1LL<<j)) ? 1 : 0);
            int d = ((x & (1LL<<j)) ? 1 : 0);

            if (d == 0) {
                // 加上所有 xor 起来为 1 的部分, 然后走到 0
                int p = trie[c].child[k^1];
                if (p) {
                    ans += sum_range(trie[p].L, trie[p].R, a[i]);
                    cnt += (trie[p].R - trie[p].L + 1);
                    ans %= mod;
                }
                c = trie[c].child[k];
            } else {
                c = trie[c].child[k^1];
            }
        }
    }
    cnt += 2;
    ans = (ans - ((cnt - 2 * m) % mod * x % mod) + mod) % mod;  // 减去重复的部分

    ans = ans * qpow(2, mod-2) % mod;
    return ans;
}

int main() {
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    if (!m) {
        cout << 0 << endl;
        return 0;
    }

    sort(a+1, a+n+1);
    for (int i = 1; i <= n; i++) insert(i, a[i]);

    ll x = find_kmax(2 * m);
    ll ans = get_sum(x);
    cout << ans << endl;
}

例6 CF1625D Binary Spiders

题意

给定 $n$ 个非负整数,和一个非整数 $k$。

求 $n$ 个非负整数的一个 subset,使得这个 subset 最大,并且 subset内 每两个数之间的 XOR $\geq k$。

其中,$n \leq 3 \times 10^5, k,a_i \in [0,2^{30}-1]$。

题解

首先我们设 $k$ 的最高位为 $j$。

那么我们将所有的 $a_i$ 分成三个部分:

  1. 最高位 $\geq j+1$ 的。
  2. 最高位 $= j$ 的。
  3. 最高位 $\leq j-1$ 的。

那么我们可以发现,对于 Case 2, Case 3,各只能最多选出 $1$ 个数来。

那么对于 Case 1 呢?

不如我们只考虑它们 bitmask的 prefix,即 $[j+1,30]$ 的这一部分。

我们注意到,如果两个 prefix 不同的数,一定是可以共存的。

如果两个数的 prefix 相同呢?是 有可能 可以共存的!

更准确的来说,如果两个数的 prefix 相同,我们可以把这部分抵消掉,那么问题就变成了最高位 $\leq j$ 的一些数可以最多选多少个出来共存。

这实际上就是 Case 2 和 Case 3 的并集,我们很容易发现这个并集内,最多可以选 $2$ 个数。

所以问题就简化了,我们只要根据 prefix 分类这些数,然后在每个分类中,选出最多两个数来即可。

而选择最多两个数判断是否 $\geq k$,用 01-Trie 可以轻松解决。


实现的过程:

  1. 先找出 $k$ 的最高位 $j$。
  2. 将所有的数字按照 $[j+1,30]$ 这一部分的prefix分类。
  3. 对于每一类,都构建 01-Trie 判断是否存在两个数的 XOR $\geq k$,存在就选择这两个数,否则只选一个。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 3e5+5;
const int maxm = 30;

int id = 1;  // 注意,从 1 开始
struct Node {
    int cnt = 0;
    int child[2];
} trie[maxn<<4];
void insert(int x, int delta) {
    int c = 1;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        if (!trie[c].child[k]) trie[c].child[k] = ++id;

        c = trie[c].child[k];
        trie[c].cnt += delta;
    }
}

// maximum query
int query(int x) {
    int c = 1;
    int res = 0;
    for (int j = maxm; j >= 0; j--) {
        int k = 0;
        if (x & (1<<j)) k = 1;
        k ^= 1;
        if (trie[trie[c].child[k]].cnt) {
            c = trie[c].child[k];
            res |= (1<<j);
        } else {
            c = trie[c].child[k^1];
            // res |= (1<<j);
        }
    }
    return res;
}

int n,k;
struct Nd {
    int val, id;
} a[maxn];
int b[maxn];

vector<int> ans;
vector<pii> vec[maxn];
int main() {
    fastio;
    cin >> n >> k;
    for (int i = 1; i <= n; i++) cin >> a[i].val, a[i].id = i;
    sort(a+1, a+n+1, [](auto a, auto b) {
        return a.val > b.val;
    });

    if (k == 0) {
        cout << n << endl;
        for (int i = 1; i <= n; i++) cout << i << " ";
        cout << endl;
        return 0;
    }
    int p = 0;
    for (int j = 0; j <= 30; j++) {
        if (k & (1<<j)) p = j;
    }

    int ptr = 0;
    for (int i = 1; i <= n; i++) {
        b[i] = a[i].val;
        for (int j = 0; j <= p; j++) {
            if (b[i] & (1<<j)) b[i] ^= (1<<j);
        }
        if (!(i > 1 && b[i] == b[i-1])) {
            ++ptr;
        }
        vec[ptr].push_back({a[i].val, a[i].id});
    }

    for (int i = 1; i <= ptr; i++) {
        if (vec[i].size() >= 2) {
            bool ok = 0;
            for (int j = 0; j < vec[i].size(); j++) {
                if (!ok && query(vec[i][j].first) >= k) {
                    for (int a = 0; a < j; a++) {
                        if ((vec[i][j].first ^ vec[i][a].first) >= k) {
                            ans.push_back(vec[i][j].second);
                            ans.push_back(vec[i][a].second);
                            ok = 1;
                            break;
                        }
                    }
                }
                insert(vec[i][j].first, 1);
            }
            if (!ok) {
                ans.push_back(vec[i][0].second);
            }
            for (int j = 0; j < vec[i].size(); j++) {
                insert(vec[i][j].first, -1);
            }
        } else if (vec[i].size() == 1) {
            ans.push_back(vec[i][0].second);
        }
    }


    if (ans.size() <= 1) cout << -1 << endl;
    else {
        cout << ans.size() << endl;
        for (int j : ans) cout << j << " ";
        cout << endl;
    }
}

例7 CF1665E. MinimizOR

题意

给定 $n$ 个非负整数 $a_1,a_2,…,a_n$ 和 $q$ 个询问,每次询问 $[l_i, r_i]$ 之间,选出两个index $j \neq k$,使得 $j,k \in [l_i, r_i]$ 且 $a_j | a_k$ 最小,求最小值。

其中,$n \in [2, 10^5], a_i \in [0, 2^{30}), q \in [1, 10^5], 1 \leq l_i < r_i \leq n$。

题解

可持久化01-trie。

可持久化保证了我们可以对于每一个区间找 $a_j | a_k$ 最小值。

怎么找呢?

考虑从 01-trie 上开始 dfs,如果发现 $0$ 对应的子树中有 $\geq 2$ 个数字,那么说明这一位可以取 $0$,否则只能取 $1$。

但有的情况下,$0$ 对应的子树中只有 $1$ 个数字 $a$,此时我们只能取 $1$,但是这个数字 $a$ 也是要拿出来的,因为它有可能在下一层被用到,所以我们单独用一个 vector<int> tmp 来储存这些 $a$。

最后注意一下,当我们在某一层取了 $0$ 时,要重新判断一下 tmp 中的数字 $a$ 在这一位是否为 $1$,如果是,说明最终答案中一定不可能考虑到这个数字 $a$,扔掉即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int M = 30;
struct Node {
    int cnt;
    int child[2];
} trie[maxn * (M+1)];
int n, root[maxn], id;
void insert(int pre, int cur, int x) {
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        trie[cur].cnt = trie[pre].cnt + 1;  // cnt 加一
        if (!trie[cur].child[c]) trie[cur].child[c] = ++id;  // 新建节点
        trie[cur].child[c^1] = trie[pre].child[c^1];  // 复制另外一个子节点
        cur = trie[cur].child[c];
        pre = trie[pre].child[c];
    }
    trie[cur].cnt = trie[pre].cnt + 1;
}
// find the maximum value for a ^ x (a is in [pre, cur])
int query(int pre, int cur, int x) {
    int res = 0;
    for (int i = M; i >= 0; i--) {
        int c = 0;
        if (x & (1 << i)) c = 1;
        c ^= 1;
        if (trie[trie[cur].child[c]].cnt - trie[trie[pre].child[c]].cnt) {  // cnt > 0
            cur = trie[cur].child[c];
            pre = trie[pre].child[c];
            res |= (1<<i);
        } else {
            cur = trie[cur].child[c^1];
            pre = trie[pre].child[c^1];
        }
    }
    return res;
}
 
int ans = 0;
vector<int> tmp, tmp2;
void dfs(int pre, int cur, int d) {
    if (d < 0) return;
    int cnt0 = 0, cnt1 = 0;
    cnt0 = trie[trie[cur].child[0]].cnt - trie[trie[pre].child[0]].cnt;
    cnt1 = trie[trie[cur].child[1]].cnt - trie[trie[pre].child[1]].cnt;
    for (int j : tmp) {
        if (!(j & (1<<d))) cnt0++;
    }
 
    if (cnt0 >= 2) {
        tmp2.clear();  // 如果为0,这里只能保留这一位为0的!
        for (int j : tmp) {
            if (!(j & (1<<d))) tmp2.push_back(j);
        }
        tmp = tmp2;
        dfs(trie[pre].child[0], trie[cur].child[0], d-1);
    } else {
        if (trie[trie[cur].child[0]].cnt - trie[trie[pre].child[0]].cnt) {
            assert(trie[trie[cur].child[0]].cnt - trie[trie[pre].child[0]].cnt == 1);
 
            int tcur = trie[cur].child[0], tpre = trie[pre].child[0], tx = ans;
            for (int j = d-1; j >= 0; j--) {
                if (trie[trie[tcur].child[0]].cnt - trie[trie[tpre].child[0]].cnt) {
                    tcur = trie[tcur].child[0];
                    tpre = trie[tpre].child[0];
                } else {
                    tcur = trie[tcur].child[1];
                    tpre = trie[tpre].child[1];
                    tx |= (1<<j);
                }
            }
            tmp.push_back(tx);
        }
        ans |= (1<<d);
        dfs(trie[pre].child[1], trie[cur].child[1], d-1);
    }
}
 
void clear(int cur) {
    for (int c = 0; c <= 1; c++) {
        if (trie[cur].child[c]) clear(trie[cur].child[c]);
    }
    memset(&trie[cur], 0, sizeof(trie[cur]));
}
void clearall() {
    id = 0;
    for (int i = 1; i <= n; i++) {
        if (root[i]) clear(root[i]), root[i] = 0;
    }
}
 
void Query(int l, int r) {
    ans = 0;
    tmp.clear();
    dfs(root[l-1], root[r], M);
    cout << ans << endl;
}
 
int T, a[maxn];
int main() {
    cin >> T;
    while (T--) {
        cin >> n;
        for (int i = 1; i <= n; i++) {
            cin >> a[i];
            if (!root[i]) root[i] = ++id;
            insert(root[i-1], root[i], a[i]);
        }
        int q; cin >> q;
        while (q--) {
            int l, r; cin >> l >> r;
            Query(l, r);
        }
        clearall();
    }
}

例8 UniversalCup 12 N.XOR Reachable

题意

给定 $n$ 个点,$m$ 条边的带权无向图。给定一个非负整数 $K$。

现在给定 $Q$ 个询问,每次询问给定一个非负整数 $D$。对于这次询问,我们保留所有权值为 $C_i$ 的边使得 $C_i \text{ xor } D < K$,然后回答在这个图中,有多少个pair $u,v$ 使得 $u$ 和 $v$ 能够互相到达。

其中,$1 \leq n,m,Q \leq 10^5, K,C_i,D_j \in [0, 2^{30})$。

询问之间互相独立,并且图中可能有重边(无自环)。

题解

首先可以看出要用到 01-Trie。

我们可以考虑将所有的边都存进 01-Trie 里面,然后对于每一个询问我们可以执行一次 DFS 找到所有这样的边。

计算有多少个 $u,v$ 可以互相到达,肯定是维护一个并查集,然后求每个联通块的大小。

但这样复杂度太高了,我们考虑怎么样只跑一次 DFS 就可以回答所有的询问?


发现本题是离线的,我们不妨把所有的询问也存进 01-Trie 里面,并且只存在叶子节点。

然后我们只跑一次 DFS,当我们 DFS 到一个叶子节点的时候,就回答这个叶子节点储存的所有询问。

怎么 DFS 呢?

我们考虑 $K$ 的每一位,我们现在 DFS 到了第 $i$ 位,假设 $K$ 的第 $i$ 位是 $1$。

那么我们在 DFS 的过程中:

  1. 对于所有第 $i$ 位为 $1$ 的询问 $D$,说明我们应该加上所有第 $i$ 位为 $1$ 的边 $C$,然后继续 DFS 考虑所有第 $i$ 位为 $0$ 的边。
  2. 对于所有第 $i$ 位为 $0$ 的询问 $D$,说明我们应该加上所有第 $i$ 位为 $0$ 的边 $C$,然后继续 DFS 考虑所有第 $i$ 位为 $1$ 的边。

• 但这样会发现一个问题,我们希望的是 DFS 到询问所在的叶子,在 DFS 过程中加上边,但这样的话需要 DFS 的询问和边不在同一棵子树里,我们无法在知道具体是哪个询问的同时,知道我们加上了哪些边。

所以我们可以先将所有询问 $D$ 执行 D ^= K,相当于将 $K$ 所有为 $1$ 的位置,都在 $D$ 中翻转了一下。

这样我们就得到

  1. 对于所有第 $i$ 位为 $1$ 的询问 $D$,说明我们应该加上所有第 $i$ 位为 $0$ 的边 $C$,然后继续 DFS 考虑所有第 $i$ 位为 $1$ 的边。
  2. 对于所有第 $i$ 位为 $0$ 的询问 $D$,说明我们应该加上所有第 $i$ 位为 $1$ 的边 $C$,然后继续 DFS 考虑所有第 $i$ 位为 $0$ 的边。

这样就没问题了。

同理,如果 $K$ 的第 $i$ 位是 $0$,我们不需要加上某一棵子树中的边,直接 DFS 即可,由于 $K$ 的第 $i$ 位是 $0$ 所以 D ^= K 的操作没有影响。

  1. 对于所有第 $i$ 位为 $1$ 的询问 $D$,DFS 考虑所有第 $i$ 位为 $1$ 的边。
  2. 对于所有第 $i$ 位为 $0$ 的询问 $D$,DFS 考虑所有第 $i$ 位为 $0$ 的边。

• 注意 DFS 回溯时,还需要把加上的边撤销掉,这个利用可撤销并查集来做就可以了,每次操作时间复杂度 $O(\log n)$。


时间复杂度是多少?

我们只执行一次 DFS,复杂度为 $O(30n)$。

对于每一条边,它只会被它的所有祖先加上,所以总共加上 + 撤销的操作也是 $O(30n)$,总时间复杂度就是 $O(30n \log n)$。

• 最后注意,DFS 的时候如果深度为 $-1$ 说明我们到达了叶子节点(不是深度为 $0$!)。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;

int n, m, K, Q;
struct Edge {
    int from, to, w;
};
vector<Edge> adj[maxn];

struct State {
    int u, v, szu, szv;
} st[maxm];

ll ans = 0, res[maxn];
ll cal(ll x) {
    return x*(x-1) / 2;
}

struct DSU {
    int par[maxn], sz[maxn], tail = 0;
    inline void init() {
        for (int i = 1; i <= n; i++) par[i] = i, sz[i] = 1;
        tail = 0;
    }

    int finds(int u) {
        if (par[u] == u) return u;
        return finds(par[u]);
    }

    void unions(int u, int v) {
        u = finds(u), v = finds(v);
        if (sz[u] < sz[v]) swap(u,v);  // sz[u] >= sz[v]
        st[++tail] = {u, v, sz[u], sz[v]};
        if (u == v) return;
        par[v] = u;
        ans = ans - cal(sz[u]) - cal(sz[v]) + cal(sz[u] + sz[v]);
        sz[u] += sz[v];
    }

    void cancel() {
        if (tail > 0) {
            int u = st[tail].u, v = st[tail].v;
            par[v] = v;
            if (sz[u] != st[tail].szu) {
                assert(sz[u] == st[tail].szu + st[tail].szv);
                ans = ans - cal(sz[u]) + cal(st[tail].szu) + cal(st[tail].szv);
            }
            sz[u] = st[tail].szu;
            sz[v] = st[tail].szv;
            tail--;
        }
    }
} dsu;

struct Node {
    int cnt = 0;
    int child[2];
    vector<Edge> vec;
    vector<int> que;  // queries 的编号
};

const int M = 30;
bool tag[33];
struct Trie01 {
    int id = 1;  // 注意,从 1 开始
    Node trie[maxn<<5];

    void insert(Edge e) {
        int x = e.w;
        int c = 1;
        for (int j = M; j >= 0; j--) {
            int k = 0;
            if (x & (1<<j)) k = 1;
            if (!trie[c].child[k]) trie[c].child[k] = ++id;
            c = trie[c].child[k];
            trie[c].cnt++;
            trie[c].vec.push_back(e);
        }
    }

    void insert(int x, int i) {
        int c = 1;
        for (int j = M; j >= 0; j--) {
            int k = 0;
            if (x & (1<<j)) k = 1;
            if (!trie[c].child[k]) trie[c].child[k] = ++id;
            c = trie[c].child[k];
            trie[c].cnt++;
        }
        trie[c].que.push_back(i);  // 只保留叶子即可
    }

    void addall(int u) {
        for (Edge e : trie[u].vec) {
            dsu.unions(e.from, e.to);
        }
    }

    void clearall(int u) {
        for (int i = 0; i < trie[u].vec.size(); i++) dsu.cancel();
    }

    // 从 d = M 开始 dfs
    void dfs(int u, int d) {
        if (!u) return;
        if (d == -1) {  // 走完了,开始统计答案
            for (int id : trie[u].que) res[id] = ans;
            return;
        }
        int c[2];
        c[0] = trie[u].child[0], c[1] = trie[u].child[1];

        if (tag[d]) {  // 需要加一边,然后dfs另外一边
            for (int o = 0; o < 2; o++) {
                addall(c[o^1]);
                dfs(c[o], d-1);
                clearall(c[o^1]);
            }
        } else {
            dfs(c[0], d-1);
            dfs(c[1], d-1);
        }
    }
} tr;


int main() {
    fastio;
    cin >> n >> m >> K;
    for (int i = 1; i <= m; i++) {
        int u, v, w; cin >> u >> v >> w;
        adj[u].push_back({u,v,w});
        adj[v].push_back({v,u,w});
        tr.insert({u,v,w});
    }
    cin >> Q;
    for (int i = 1; i <= Q; i++) {
        int D; cin >> D; 
        D ^= K;
        tr.insert(D, i);
    }
    dsu.init();
    for (int i = M; i >= 0; i--) tag[i] = (K & (1<<i));

    tr.dfs(1, M);
    for (int i = 1; i <= Q; i++) cout << res[i] << "\n";
}

参考链接

  1. https://oi-wiki.org/string/trie/#_6
  2. https://oi-wiki.org/ds/persistent-trie/