介绍

莫队算法是一种基于分块思想的暴力算法,一般应用于同时满足以下条件的区间问题中:

  1. 已知 $[L,R]$ 之间的答案,能在 $O(1)$ 时间内转移到 $[L+1,R], [L-1,R], [L,R+1], [L,R-1]$ 的答案。
  2. 所有询问均离线。
  3. 不存在修改。

我们用模版举个例子:

题意

给定一个长度为 $N$ 的正整数序列 $a$,给定一个 $k$,满足 $\forall i, a_i \in [1,k]$。

现在有 $M$ 个询问,每个询问给定一个区间 $[l,r]$,求 $\sum_{i=1}^kc_i^2$

其中 $c_i$ 为数字 $i$ 在 $[l,r]$ 中的出现次数。

数据范围:$1 \leq n,m,k \leq 5\times10^4$

算法

Part1 O(1)的状态转移

对于上面的例题,我们可以发现从 $[L,R]$ 转移到 $[L,R+1]$ 是 $O(1)$ 的。

我们维护两个指针 $l,r$,并且维护一个 cnt[] 数组来记录当前区间的 $c_i$,在 $r$ 右移一格的时候,加上对应的 $cnt$,然后要维护的 $\sum_{i=1}^kc_i^2$ 也很好转移,计算一下,就会得到

$$s_{l,r+1} = s_{l,r} + 2\times c_{a_{r+1}} + 1$$

同理对于其他三种情况,转移都是 $O(1)$ 的。

所以,假设我们有两个询问 $[L_1, R_1], [L_2, R_2]$,我们在询问完 $[L_1, R_1]$ 后,将左右指针一个个移动到 $[L_2, R_2]$ 似乎就可以节省一点时间了。(如果它们离得比较近的话)

Part2 莫队思路

既然我们可以通过维护两个指针 $l,r$ 来快速转移,我们又事先知道所有的询问(因为询问离线),那有什么办法将这些询问靠近一些,来节省更多时间呢?

分块思想!

我们将区间划分为 $\sqrt n$ 块,然后对于每个询问 $[L_i,R_i]$,我们根据 $L_i$ 的值,把它放进对应的块中。

然后,我们将所有的询问首先根据 所在块的编号 来sort,对于同一块内的询问,根据 $R_i$ 从小到大 来sort。

最后,根据sort的顺序来处理每个询问,询问之间的转移 就按照上面的左右指针移动来处理。这样我们能在 $O(n\sqrt n)$ 时间内处理好每一个区间。

算法步骤

  1. 预处理所有询问,记录询问的 l,r,记录 be (代表 l 对应是哪个块),记录 id(代表原先是第几个询问)。
  2. 根据 be 作为第一关键字,r 作为第二关键字进行sort。
  3. 定义global variable int l = 1, r = 0, ans = 0
  4. 按照sort后的顺序进行询问,调整 l,r 指针,并相应更新 ans,然后将 ans 根据 id 放入答案数组中。

需要注意的点

  1. 注意在转移过程中,使用的是 --l 还是 l++r 还有更新 ans 的时候也类似,要根据具体情况来看。
  2. 注意初始情况下, l = 1, r = 0
  3. be 是根据 l 的位置决定的。

复杂度证明

先注意:

  1. 同一个块内的 $L_i$ 并没有顺序。
  2. 同一个块内的 $R_i$ 没有限制,可以横跨整个区间。

左指针在块内移动 的复杂度:注意到,同一个块内的 $L_i$ 并没有顺序,所以每次询问可能有 $O(B)$ 的复杂度($B$为块的大小)。总复杂度为 $O(mB)$

右指针在块内移动 的复杂度:因为是 $R_i$ 是有序的,所以在同一个块内移动的总复杂度为 $O(n)$

左指针在块之间移动 的复杂度:每次移动复杂度为 $O(2B)$。总复杂度为 $O(\frac{n}{B} * 2B) = O(2n)$

右指针在块之间移动 的复杂度:总共有 $\frac{n}{B}$ 个块,每次在块之间移动没有限制,复杂度为 $O(n)$。总复杂度为 $O(\frac{n^2}{B})$

综上,复杂度为 $O(mB) + O(n) + O(\frac{n^2}{B})$

当我们取 $B = \sqrt n$ 时,复杂度为 $O(n \sqrt n)$

• 实际上最优复杂度应该取 $B = \frac{n}{\sqrt m}$,总复杂度为 $O(n \sqrt m)$

例题

例1 小B的询问

就是上面的例题,这里直接放代码。

代码
using namespace std;
#include <bits/stdc++.h>
#define ll long long

const int maxn = 5e4+5;
const int maxm = 5e4+5;

int n,m,k;
int sz;
int arr[maxn];
ll cnt[maxn];
ll ans[maxn];
struct query {
    int l,r,be,id;
} q[maxm];

bool cmp(query& a, query& b) {
    if (a.be == b.be) return a.r < b.r;
    return a.be < b.be;
}

int l = 1,r = 0;

ll add(int x) {
    ll res = 2LL * cnt[arr[x]] + 1LL;
    cnt[arr[x]]++;
    return res;
}

ll del(int x) {
    ll res = -2LL * cnt[arr[x]] + 1LL;
    cnt[arr[x]]--;
    return res;
}

int main() {
    cin >> n >> m >> k;
    sz = sqrt(n);
    for (int i = 1; i <= n; i++) cin >> arr[i];
    for (int i = 1; i <= m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].id = i;
        q[i].be = (q[i].l-1) / sz;
    }
    sort(q+1, q+m+1, cmp);

    ll res = 0;
    for (int i = 1; i <= m; i++) {
        int ql = q[i].l, qr = q[i].r;
        while (r < qr) res += add(++r);
        while (r > qr) res += del(r--);
        while (l < ql) res += del(l++);
        while (l > ql) res += add(--l);
        ans[q[i].id] = res;
    }
    for (int i = 1; i <= m; i++) cout << ans[i] << "\n";
}

例2 小Z的袜子

题意

有 $N$ 个袜子,每个袜子 $i$ 有一个颜色 $c_i$,给定 $M$ 个询问 $[L,R]$,每次询问回答 $[L,R]$ 区间内随机抽两个袜子,颜色相同的概率?

其中 $N,M \leq 50000, c_i \in [1,N]$

题解

维护 分子和分母

每次区间长度加 $1$:分母增加 $len$($len$ 为增加前的区间长度),分子增加 $cnt_{c_i}$ ($cnt_{c_i}$ 为新增的颜色 $c_i$ 原来的数量)。

每次区间长度减 $1$:分母减少 $len-1$($len$ 为减少前的区间长度),分子减少 $cnt_{c_i} - 1$ ($cnt_{c_i}$ 为减少的颜色 $c_i$ 原来的数量)。

代码
#include <bits/stdc++.h>
using namespace std;

#define ll long long
const int maxn = 5e4+5;
const int maxm = 5e4+5;

int n,m;
struct query {
    int l,r,be,id;
    ll nu,de;
} q[maxm];

int arr[maxn];
int cnt[maxn];

bool cmp(query& a, query& b) {
    if (a.be == b.be) {
        return a.r < b.r;
    }
    return a.be < b.be;
}

ll nu = 0, de = 0;

ll gcd(ll a, ll b) {
    if (!b) return a;
    return gcd(b, a%b);
}
int sz;

int main() {
    scanf("%d%d",&n,&m);
    sz = sqrt(n);
    for (int i = 1; i <= n; i++) scanf("%d",&arr[i]);
    for (int i = 1; i <= m; i++) {
        int l,r; scanf("%d%d",&l,&r);
        q[i].id = i;
        q[i].be = (l-1)/sz;
        q[i].l = l, q[i].r = r;
    }
    sort(q+1, q+m+1, cmp);

    ll l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        int ql = q[i].l, qr = q[i].r, id = q[i].id;
        if (ql == qr) {
            q[id].nu = 0, q[id].de = 1;
            continue;
        }
        while (r < qr) de += (r-l+1), r++, nu += cnt[arr[r]], cnt[arr[r]]++;
        while (r > qr) de -= (r-l), nu -= (cnt[arr[r]] - 1), cnt[arr[r]]--, r--;
        while (l > ql) de += (r-l+1), l--, nu += cnt[arr[l]], cnt[arr[l]]++;
        while (l < ql) de -= (r-l), nu -= (cnt[arr[l]] - 1), cnt[arr[l]]--, l++;
        q[id].nu = nu, q[id].de = de;
    }

    for (int i = 1; i <= m; i++) {
        nu = q[i].nu, de = q[i].de;
        if (nu == 0) {
            printf("0/1\n");
            continue;
        }
        ll g = gcd(nu,de);
        nu /= g, de /= g;
        printf("%lld/%lld\n",nu,de);
    }
}

例3 CF617E

题意

给定 $n$ 个整数 $a_1,a_2,…,a_n$,还有一个整数 $k$ ,以及 $m$ 个询问 $[l,r]$,每次询问求 有多少个$i,j$ 满足:

  1. $l \leq i \leq j \leq r$
  2. $a_i \text{ xor } a_{i+1} \text{ xor } … \text{ xor } a_j = k$

其中,$1 \leq n,m \leq 10^5, 0 \leq k \leq 10^6, 0 \leq a_i \leq 10^6$

题解

首先定义一个前缀 $\text{ xor }$ 数组满足 $s_i = a_1 \text{ xor } a_2 \text{ xor } … \text{ xor } a_i$,这样问题转化为:

每次询问求 有多少个$i,j$ 满足:

  1. $l \leq i \leq j \leq r$
  2. $s_i \text{ xor } s_j = k$

注意到,$s_j = s_i \text{ xor } k$,所以我们可以维护一个 cnt[] 数组,记录一下当前区间每个元素出现了多少次。

然后,比如在区间扩张的过程中,就检查 cnt[] 中当前元素 cur 出现的次数,给 ans 加上,然后 cnt[cur ^ k]++;

代码
using namespace std;
#include <bits/stdc++.h>

#define ll long long

const int mod = 1e9+7;
const int maxn = 1e5+5;
const int maxm = 1e6+5;

int n,m,k;
struct query {
    int l,r,id,be;
} q[maxn];

bool cmp(query a, query b) {
    if (a.be == b.be) return a.r < b.r;
    return a.be < b.be;
}

int l = 0, r = -1;
int cnt[2*maxm];
int s[maxn];
ll b[maxn];
double start;

ll ans = 0;

void add(int x) {
    ans += (ll)cnt[s[x]];
    cnt[s[x] ^ k]++;
}

void del(int x) {
    cnt[s[x] ^ k]--;
    ans -= (ll)cnt[s[x]];
}

void ask(int L, int R) {
    while (r < R) add(++r);
    while (r > R) del(r--);
    while (l < L) del(l++);
    while (l > L) add(--l);
}

int main() {
    fastio;
    cin >> n >> m >> k;
    for (int i = 1; i <= n; i++) cin >> s[i];
    for (int i = 1; i <= n; i++) s[i] ^= s[i-1];
    int sz = sqrt(n);

    for (int i = 1; i <= m; i++) {
        cin >> q[i].l >> q[i].r;
        q[i].l--;
        q[i].id = i;
        q[i].be = (q[i].l-1)/sz;
    }
    sort(q+1, q+m+1, cmp);

    for (int i = 1; i <= m; i++) {
        ask(q[i].l, q[i].r);
        b[q[i].id] = ans;
    }

    for (int i = 1; i <= m; i++) cout << b[i] << "\n";
}

带修莫队

略。例题参考:https://www.luogu.com.cn/problem/P1903

参考链接

  1. https://ouuan.github.io/post/%E8%8E%AB%E9%98%9F%E5%B8%A6%E4%BF%AE%E8%8E%AB%E9%98%9F%E6%A0%91%E4%B8%8A%E8%8E%AB%E9%98%9F%E8%AF%A6%E8%A7%A3/#%E5%B8%A6%E4%BF%AE%E8%8E%AB%E9%98%9F
  2. https://www.cnblogs.com/WAMonster/p/10118934.html