二维ST表
Contents
介绍
二维ST表就是在一维的ST表上加一个维度,这样可以 $O(nm\log n \log m)$ 内预处理以后,$O(1)$ 询问一个矩阵的最大/最小值。
例题
例1 洛谷P2216[HAOI2007]理想的正方形
题意
给定一个 $n \times m$ 的矩阵,每个位置有一个非负整数,给定 $k$,找出一个 $k \times k$ 的正方形使得正方形内最大值和最小值的差值最小。
其中,$n,m \leq 1000$。
题解
二维 ST 表,由于是找正方形,所以只用维护额外一个维度,即:
st[i][j][k]
代表以 $(i,j)$ 为左上角,以 $(i+2^k-1, j+2^k-1)$ 为右下角的矩阵的最大/最小值。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e3+5;
int st1[maxn][maxn][12], st2[maxn][maxn][12];
int bin[maxn];
int a[maxn][maxn];
int n, m, k;
// 查询以 (i,j) 为左上角的 k*k 的矩阵中的最大值 - 最小值
int ask_st(int i, int j) {
int l = bin[k];
int res1 = max({st1[i][j][l], st1[i+k-(1<<l)][j][l], st1[i][j+k-(1<<l)][l], st1[i+k-(1<<l)][j+k-(1<<l)][l]});
int res2 = min({st2[i][j][l], st2[i+k-(1<<l)][j][l], st2[i][j+k-(1<<l)][l], st2[i+k-(1<<l)][j+k-(1<<l)][l]});
return res1 - res2;
}
void build_st() {
bin[1] = 0; bin[2] = 1;
for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) st1[i][j][0] = st2[i][j][0] = a[i][j];
}
for (int k = 1; k < 12; k++) {
for (int i = 1; i + (1<<k) - 1 <= n; i++) {
for (int j = 1; j + (1<<k) - 1 <= m; j++) {
st1[i][j][k] = max({st1[i][j][k-1], st1[i+(1<<(k-1))][j][k-1], st1[i][j+(1<<(k-1))][k-1], st1[i+(1<<(k-1))][j+(1<<(k-1))][k-1]});
st2[i][j][k] = min({st2[i][j][k-1], st2[i+(1<<(k-1))][j][k-1], st2[i][j+(1<<(k-1))][k-1], st2[i+(1<<(k-1))][j+(1<<(k-1))][k-1]});
}
}
}
}
int main() {
cin >> n >> m >> k;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) cin >> a[i][j];
}
build_st();
int ans = 1e9;
for (int i = 1; i + k - 1 <= n; i++) {
for (int j = 1; j + k - 1 <= m; j++) {
ans = min(ans, ask_st(i, j));
}
}
cout << ans << endl;
}
例2 CF713D. Animals and Puzzle
题意
给定一个 $n \times m$ 的01矩阵。
给定 $t$ 个询问,每次询问一个矩阵 $x_1,y_1,x_2,y_2$,回答以 $(x_1,y_1)$ 为左上角,$(x_2,y_2)$ 为右下角的矩阵中,最大的全 $1$ 正方形的边长。
其中,$n,m \leq 1000, t \leq 10^6$。
题解
首先我们可以利用二分处理出一个数组 dp[i][j]
,代表以 $(i,j)$ 作为左上角,最大的全 $1$ 正方形的边长。
然后对于每次询问 $x_1,y_1,x_2,y_2$,二分一下答案,对于二分到的答案 $k$,我们只需要check一下 $(x_1,y_1,x_2-k+1,y_2-k+1)$ 这个矩阵中,dp
的最大值即可,用二维ST表即可实现 $O(1)$ 查询。
所以最后总复杂度是 $O(nm \log(n) \log(m))$。
• 记得将 ST 表开成 short,否则炸内存。
• 记得开快读,否则头都给T飞。
代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e3+3;
int n, m, a[maxn][maxn], sum[maxn][maxn], dp[maxn][maxn];
int getsum(int i1, int j1, int i2, int j2) {
return sum[i2][j2] - sum[i2][j1-1] - sum[i1-1][j2] + sum[i1-1][j1-1];
}
short st[maxn][maxn][12][12], bin[maxn];
int ask_st(int i1, int j1, int i2, int j2) {
int k1 = bin[i2-i1+1], k2 = bin[j2-j1+1];
return max({st[i1][j1][k1][k2], st[i2-(1<<k1)+1][j1][k1][k2], st[i1][j2-(1<<k2)+1][k1][k2], st[i2-(1<<k1)+1][j2-(1<<k2)+1][k1][k2]});
}
void build_st() {
bin[1] = 0; bin[2] = 1;
for (int i = 3; i < maxn; i++) bin[i] = bin[i>>1] + 1;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) st[i][j][0][0] = dp[i][j];
}
for (int k1 = 1; k1 < 12; k1++) {
for (int i = 1; i + (1<<k1) - 1 <= n; i++) {
for (int j = 1; j <= m; j++) {
st[i][j][k1][0] = max(st[i][j][k1-1][0], st[i+(1<<(k1-1))][j][k1-1][0]);
}
}
}
for (int k2 = 1; k2 < 12; k2++) {
for (int i = 1; i <= n; i++) {
for (int j = 1; j + (1<<k2) - 1 <= m; j++) {
st[i][j][0][k2] = max(st[i][j][0][k2-1], st[i][j+(1<<(k2-1))][0][k2-1]);
}
}
}
for (int k1 = 1; k1 < 12; k1++) {
for (int k2 = 1; k2 < 12; k2++) {
for (int i = 1; i + (1<<k1) - 1 <= n; i++) {
for (int j = 1; j + (1<<k2) - 1 <= m; j++) {
st[i][j][k1][k2] = max({st[i][j][k1-1][k2-1], st[i+(1<<(k1-1))][j][k1-1][k2],
st[i][j+(1<<(k2-1))][k1][k2-1], st[i+(1<<(k1-1))][j+(1<<(k2-1))][k1-1][k2-1]});
}
}
}
}
}
int main() {
fastio;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) cin >> a[i][j];
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
sum[i][j] = sum[i-1][j] + sum[i][j-1] - sum[i-1][j-1] + a[i][j];
}
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) {
if (!a[i][j]) dp[i][j] = 0;
else {
int low = 1, high = min(n-i+1, m-j+1), res = 1;
while (low <= high) {
int mid = (low+high) >> 1;
if (getsum(i, j, i+mid-1, j+mid-1) == mid*mid) {
res = mid;
low = mid+1;
} else high = mid - 1;
}
dp[i][j] = res;
}
}
}
build_st();
int t; cin >> t;
while (t--) {
int i1, j1, i2, j2;
cin >> i1 >> j1 >> i2 >> j2;
if (!getsum(i1,j1,i2,j2)) {
cout << 0 << "\n";
continue;
}
int low = 1, high = min(i2-i1+1, j2-j1+1), res = 1;
while (low <= high) {
int k = (low + high) >> 1;
if (ask_st(i1,j1,i2-k+1,j2-k+1) >= k) {
res = k;
low = k+1;
} else high = k-1;
}
cout << res << "\n";
}
}