介绍

高斯消元是矩阵操作里最基础的一个,主要用于解形如 $a_1x_1 + a_2x_2 + a_3x_3 = b_1$ 之类的线性方程组。

步骤

  1. 按 column 进行遍历
  2. 遍历到第 $k$ 个 column 时,在这个column中,寻找一个 $maxrow$ 使得 $A[maxrow][k]$ 最大。
  3. 将第 $k$ 行 与 第 $maxrow$ 行进行交换。
  4. 交换后,以第 $k$ 行作为pivot,减掉其他所有行,消去第 $k$ 列其他行上的数字。
  5. 最后会得到一个 diagonal matrix,除以对应系数即可得到 $\vec{b}$

复杂度:$O(n^3)$

先上板子:

模版 P3389 【模板】高斯消元法

题意

给定一个线性方程组,进行求解。

输入一个 $n \times (n+1)$ 的矩阵,每一行为 $a_1, a_2, …, a_n, b_k$,代表一组方程。

#include <bits/stdc++.h>
 
using namespace std;
const double eps = (double)1e-6;

double A[maxn][maxn];
int n; 

void solve() {
    for (int k = 1; k <= n; k++) {
        int maxrow = k;  // 当前行初始情况下为:当前列对应的行数
        for (int i = k+1; i <= n; i++) {  // 从当前行往下找
            if (abs(A[i][k]) > abs(A[maxrow][k])) maxrow = i;
        }
        for (int i = 1; i <= n+1; i++) swap(A[k][i], A[maxrow][i]);  //交换两行, 保证A[k][k]最大
        if (abs(A[k][k]) < eps) {
            cout << "Infinite solutions." << endl;  // 出现一行全是 0 的情况,一般说明有无数个解
            return;
        } 
        for (int i = 1; i <= n; i++) {
            if (i == k) continue;
            double m = A[i][k] / A[k][k];
            for (int j = 1; j <= n+1; j++) {  //更新上下的行
                A[i][j] -= (m * A[k][j]); 
            }
        }
    }

    for (int i = 1; i <= n; i++) {
        A[i][n+1] /= A[i][i];  // 最后除以斜线上的系数
        A[i][i] = 1.0;
        printf("%.2f\n", A[i][n+1]);
    }
}

int main() {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n+1; j++) {
            cin >> A[i][j];  // 输入矩阵,j=n+1 代表 b
        }
    }
    solve();
}

例题

例1 洛谷P2447 [SDOI2010]外星千足虫

题意

有 $n$ 个虫子,每只虫子来自地球或者火星。来自地球的虫子有偶数个脚,火星的虫子有奇数个脚。

现在给定 $m$ 个方程,代表已知哪些虫子的脚数量之和为奇数或者偶数。

求每个虫子的归属地(地球或者火星)。

如果有解,输出最小的 $k \leq n$ 使得仅用前 $k$ 个方程即可得到结果。

如果无解,则输出 Cannot Determine。

其中,$1 \leq n \leq 10^3, 1 \leq m \leq 2 \times 10^3$。

题解

高斯消元即可。现在有几个问题:

Q1. 如果有解,如何求出 $k$?

A1. 注意到高斯消元的第二步是寻找一个 $maxrow$。然而这个题中,矩阵中的数要么为 $0$,要么为 $1$。所以只要找到 $1$,就不需要再往下找了。

所以在寻找 $1$ 的过程中记录一下用到的 row number 的最大值即可,一旦找到就立刻break。


Q2. 高斯消元是 $O(n^3)$ 的,复杂度不对?

A2. 再次注意到这个矩阵仅由 $0,1$ 组成,可以用 bitset 进行优化。优化幅度为 $\frac{1}{w}$,其中 $w=32$。

所以复杂度为 $O(\frac{n^2m}{w})$,足以通过本题。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e3+2;
const int maxm = 7e5+5;

bitset<maxn> A[maxn+maxn];
int ans;
int n,m; 

void solve() {
    for (int k = n; k >= 1; k--) {
        int maxrow = n-k+1;
        for (int i = n-k+1; i <= m; i++) {
            ans = max(ans, i);
            if (A[i].test(k)) {
                maxrow = i;
                break;
            }
        }
        swap(A[n-k+1], A[maxrow]);
        if (!A[n-k+1].test(k)) {
            cout << "Cannot Determine\n";
            return;
        }
        for (int i = 1; i <= m; i++) {
            if (i == n-k+1) continue;
            if (A[i].test(k))
                A[i] ^= A[n-k+1];
        }
    }
    cout << ans << endl;
    for (int i = 1; i <= n; i++) {
        if (A[i].test(0)) cout << "?y7M#\n";
        else cout << "Earth\n";
    }
}


int main() {
    cin >> n >> m;
    if (m < n) {
        cout << "Cannot Determine\n";
        return 0;
    }
    ans = 1;

    for (int i = 1; i <= m; i++) {
        string s1,s2; cin >> s1 >> s2;
        A[i] = bitset<maxn>(s1 + s2);
    }

    solve();
}

例2 洛谷P4035 [JSOI2008]球形空间产生器

题意

在 $n$ 维空间中,有一个球体。现在已知该球面上 $(n+1)$ 个点的坐标,求出球心坐标?

其中,$1 \leq n \leq 10$,数据保证有唯一解。

题解

设球心坐标为 $(b_1,b_2,…,b_n)$,那么就可以列出 $(n+1)$ 个方程:

$(b_1 - a_{11})^2 + (b_2 - a_{12})^2 + … + (b_n - a_{1n})^2 = R^2$ …

但是这并不是线性方程组。

为了使其成为线性方程组,将 $(n+1)$ 个方程进行相减,得到

  1. 方程 $2$ 减去方程 $1$
  2. 方程 $3$ 减去方程 $1$ …

即可得到 $n$ 个线性方程。

然后用高斯消元即可。

代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 12;

double A[maxn][maxn], arr[maxn][maxn];
int n; 

void solve() {
    for (int k = 1; k <= n; k++) {
        int maxrow = k;
        for (int i = k+1; i <= n; i++) {
            if (abs(A[i][k]) > abs(A[maxrow][k])) maxrow = i;
        }
        for (int i = 1; i <= n+1; i++) swap(A[k][i], A[maxrow][i]);  //交换两行, 保证A[k][k]最大
        if (abs(A[k][k]) < eps) {
            cout << "No Solution" << endl;
            return;
        } 
        for (int i = 1; i <= n; i++) {
            if (i == k) continue;
            double m = A[i][k] / A[k][k];
            for (int j = 1; j <= n+1; j++) {  //更新上下的行
                A[i][j] -= (m * A[k][j]); 
            }
        }
    }

    for (int i = 1; i <= n; i++) {
        A[i][n+1] /= A[i][i];  // 最后除以斜线上的系数
        A[i][i] = 1.0;
        printf("%.3f ", A[i][n+1]);
    }
}

double sq(double a) {
    return a*a;
}

int main() {
    cin >> n;
    for (int i = 1; i <= n+1; i++) {
        for (int j = 1; j <= n; j++) {
            cin >> arr[i][j];
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            A[i][j] = arr[i+1][j] - arr[i][j];
        }

        for (int k = 1; k <= n; k++) {
            A[i][n+1] += sq(arr[i+1][k]) - sq(arr[i][k]); 
        }
        A[i][n+1] /= 2;
    }
    solve();
}