介绍

二维ST表就是在一维的ST表上加一个维度,这样可以 $O(nm\log n \log m)$ 内预处理以后,$O(1)$ 询问一个矩阵的最大/最小值。

例题

例1 洛谷P2216[HAOI2007]理想的正方形

题意

给定一个 $n \times m$ 的矩阵,每个位置有一个非负整数,给定 $k$,找出一个 $k \times k$ 的正方形使得正方形内最大值和最小值的差值最小。

其中,$n,m \leq 1000$。

题解

二维 ST 表,由于是找正方形,所以只用维护额外一个维度,即:

st[i][j][k] 代表以 $(i,j)$ 为左上角,以 $(i+2^k-1, j+2^k-1)$ 为右下角的矩阵的最大/最小值。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e3+5;

int st1[maxn][maxn][12], st2[maxn][maxn][12];
int bin[maxn];
int a[maxn][maxn];
 
int n, m, k;
// 查询以 (i,j) 为左上角的 k*k 的矩阵中的最大值 - 最小值
int ask_st(int i, int j) {
    int l = bin[k];
    int res1 = max({st1[i][j][l], st1[i+k-(1<<l)][j][l], st1[i][j+k-(1<<l)][l], st1[i+k-(1<<l)][j+k-(1<<l)][l]});
    int res2 = min({st2[i][j][l], st2[i+k-(1<<l)][j][l], st2[i][j+k-(1<<l)][l], st2[i+k-(1<<l)][j+k-(1<<l)][l]});
    return res1 - res2;
}

void build_st() {
    bin[1] = 0; bin[2] = 1;
    for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) st1[i][j][0] = st2[i][j][0] = a[i][j];
    }
    for (int k = 1; k < 12; k++) {
        for (int i = 1; i + (1<<k) - 1 <= n; i++) {
            for (int j = 1; j + (1<<k) - 1 <= m; j++) {
                st1[i][j][k] = max({st1[i][j][k-1], st1[i+(1<<(k-1))][j][k-1], st1[i][j+(1<<(k-1))][k-1], st1[i+(1<<(k-1))][j+(1<<(k-1))][k-1]});
                st2[i][j][k] = min({st2[i][j][k-1], st2[i+(1<<(k-1))][j][k-1], st2[i][j+(1<<(k-1))][k-1], st2[i+(1<<(k-1))][j+(1<<(k-1))][k-1]});
            }
        }
    }
}


int main() {
    cin >> n >> m >> k;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) cin >> a[i][j];
    }
    build_st();
    int ans = 1e9;
    for (int i = 1; i + k - 1 <= n; i++) {
        for (int j = 1; j + k - 1 <= m; j++) {
            ans = min(ans, ask_st(i, j));
        }
    }
    cout << ans << endl;
}

例2 CF713D. Animals and Puzzle

题意

给定一个 $n \times m$ 的01矩阵。

给定 $t$ 个询问,每次询问一个矩阵 $x_1,y_1,x_2,y_2$,回答以 $(x_1,y_1)$ 为左上角,$(x_2,y_2)$ 为右下角的矩阵中,最大的全 $1$ 正方形的边长。

其中,$n,m \leq 1000, t \leq 10^6$。

题解

首先我们可以利用二分处理出一个数组 dp[i][j],代表以 $(i,j)$ 作为左上角,最大的全 $1$ 正方形的边长。

然后对于每次询问 $x_1,y_1,x_2,y_2$,二分一下答案,对于二分到的答案 $k$,我们只需要check一下 $(x_1,y_1,x_2-k+1,y_2-k+1)$ 这个矩阵中,dp 的最大值即可,用二维ST表即可实现 $O(1)$ 查询。

所以最后总复杂度是 $O(nm \log(n) \log(m))$。

• 记得将 ST 表开成 short,否则炸内存。

• 记得开快读,否则头都给T飞。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e3+3;

int n, m, a[maxn][maxn], sum[maxn][maxn], dp[maxn][maxn];
int getsum(int i1, int j1, int i2, int j2) {
    return sum[i2][j2] - sum[i2][j1-1] - sum[i1-1][j2] + sum[i1-1][j1-1];
}
short st[maxn][maxn][12][12], bin[maxn];
int ask_st(int i1, int j1, int i2, int j2) {
    int k1 = bin[i2-i1+1], k2 = bin[j2-j1+1];
    return max({st[i1][j1][k1][k2], st[i2-(1<<k1)+1][j1][k1][k2], st[i1][j2-(1<<k2)+1][k1][k2], st[i2-(1<<k1)+1][j2-(1<<k2)+1][k1][k2]});
}

void build_st() {
    bin[1] = 0; bin[2] = 1;
    for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) st[i][j][0][0] = dp[i][j];
    }
    for (int k1 = 1; k1 < 12; k1++) {
        for (int i = 1; i + (1<<k1) - 1 <= n; i++) {
            for (int j = 1; j <= m; j++) {
                st[i][j][k1][0] = max(st[i][j][k1-1][0], st[i+(1<<(k1-1))][j][k1-1][0]);
            }
        }
    }
    for (int k2 = 1; k2 < 12; k2++) {
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j + (1<<k2) - 1 <= m; j++) {
                st[i][j][0][k2] = max(st[i][j][0][k2-1], st[i][j+(1<<(k2-1))][0][k2-1]);
            }
        }
    }

    for (int k1 = 1; k1 < 12; k1++) {
        for (int k2 = 1; k2 < 12; k2++) {
            for (int i = 1; i + (1<<k1) - 1 <= n; i++) {
                for (int j = 1; j + (1<<k2) - 1 <= m; j++) {
                    st[i][j][k1][k2] = max({st[i][j][k1-1][k2-1], st[i+(1<<(k1-1))][j][k1-1][k2], 
                        st[i][j+(1<<(k2-1))][k1][k2-1], st[i+(1<<(k1-1))][j+(1<<(k2-1))][k1-1][k2-1]});
                }
            }
        }
    }
}

int main() {
    fastio;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) cin >> a[i][j];
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            sum[i][j] = sum[i-1][j] + sum[i][j-1] - sum[i-1][j-1] + a[i][j];
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (!a[i][j]) dp[i][j] = 0;
            else {
                int low = 1, high = min(n-i+1, m-j+1), res = 1;
                while (low <= high) {
                    int mid = (low+high) >> 1;
                    if (getsum(i, j, i+mid-1, j+mid-1) == mid*mid) {
                        res = mid;
                        low = mid+1;
                    } else high = mid - 1;
                }
                dp[i][j] = res;
            }
        }
    }
    build_st();

    int t; cin >> t;
    while (t--) {
        int i1, j1, i2, j2;
        cin >> i1 >> j1 >> i2 >> j2;
        if (!getsum(i1,j1,i2,j2)) {
            cout << 0 << "\n";
            continue;
        }
        int low = 1, high = min(i2-i1+1, j2-j1+1), res = 1;
        while (low <= high) {
            int k = (low + high) >> 1;
            if (ask_st(i1,j1,i2-k+1,j2-k+1) >= k) {
                res = k;
                low = k+1;
            } else high = k-1;
        }
        cout << res << "\n";
    }
}