【模板】矩阵求逆

Luogu P4783

题目描述

求一个 \(N\times N\) 的矩阵的逆矩阵。答案对 \({10}^9+7\)

输入格式

第一行有一个整数 \(N\),代表矩阵的大小;

接下来 \(N\) 行,每行 \(N\) 个整数,其中第 \(i\) 行第 \(j\) 列的数代表矩阵中的元素 \(a_{i j}\)。

输出格式

若矩阵可逆,则输出 \(N\) 行,每行 \(N\) 个整数,其中第 \(i\) 行第 \(j\) 列的数代表逆矩阵中的元素 \(b_{i j}\),答案对 \({10}^9+7\)

否则只输出一行 No Solution

样例 #1

样例输入 #1

3
1 2 8
2 5 6
5 1 2

样例输出 #1

718750005 718750005 968750007
171875001 671875005 296875002
117187501 867187506 429687503

样例 #2

样例输入 #2

3
3 2 4
7 2 9
2 4 3

样例输出 #2

No Solution

提示

对 \(30 \%\) 的数据有 \(N\le 100\);
对 \(100 \%\) 的数据有 \(N\le 400\),所有 \(0 \le a_{i j} < {10}^9 + 7\)。

Solution

假设有一个矩阵 \(A\),如果想要计算除法,会发现我们没有定义矩阵的除法,不过会联想到数的除法,所以就想到了逆元这种方法。

矩阵的逆 \(A^{-1}\) 定义为满足 \(A^{-1}\times A=A\times A^{-1}=I\) 的矩阵(\(I\) 表示单位矩阵,即主对角线上的值都为 \(1\)

假设矩阵 \(A\) 能够通过进行一系列的矩阵乘法来变成单位矩阵 \(I\) ,那么这过程中的所有矩阵的积就是需要找到的逆元。假设这过程中用到的矩阵为 \(p_1,p_2,p_3\cdots,p_k\),那么就有 $$p_k\times p_{k-1}\times \cdots \times p_2 \times p_1\times A=I$$

也就是说需要找到合适的矩阵数列 \(p\) 来求矩阵逆。需要将矩阵逐步变成一个主对角线为 \(1\)

将第二行乘上一个数 \(c\):

\[\begin{bmatrix} 1&0&0&\cdots&0\\ 0&c&0&\cdots&0\\ 0&0&1&\cdots&0\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 0&0&0&\cdots&1 \end{bmatrix} \]

倍加:

\[\begin{bmatrix} 1&0&0&\cdots&0\\ 0&1&0&\cdots&0\\ 0&c&1&\cdots&0\\ \vdots&\vdots&\vdots&\ddots&\vdots\\ 0&0&0&\cdots&1 \end{bmatrix} \]

所以就有一种做法:将原矩阵 \(A\) 和一个单位矩阵 \(I\) 左右拼接组成一个新的 \(n\times 2n\) 的矩阵,然后对左侧的矩阵 \(A\) 进行高斯消元,最后右边的单位矩阵 \(I\) 自然就乘上了数列 \(p\),所以如果将初始矩阵表示为 \((A,I)\),那么最终矩阵就是 \((I,A^{-1})\)

因为是模 \(10^9+7\),所以除法需要用逆元,逆元用费马小定理搭配快速幂即可。整个代码的总时间复杂度为 \(\mathcal O(n^3\log p)\),其中 \(n^3\) 来自高斯消元,\(\log p\)

Code

#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof a)
#define int long long
using namespace std;
template<typename T> void read(T &k)
{
	k=0;T flag=1;char b=getchar();
	while (!isdigit(b)) {flag=(b=='-')?-1:1;b=getchar();}
	while (isdigit(b)) {k=k*10+b-48;b=getchar();}
	k*=flag;
}
const int MOD=1e9+7;
int Fpow(int a,int b)
{
	int base=a%MOD,res=1;
	while (b)
	{
		if (b&1) res=res*base%MOD;
		base=base*base%MOD,b>>=1;
	}
	return res;
}
const int _SIZE=4e2;
int a[_SIZE+5][(_SIZE<<1)+5];
int n;
void Calc()
{
	for (int i=1;i<=n;i++)
	{
		int r=i;
		for (int j=i+1;j<=n;j++)
			if (a[j][i]>a[r][i]) r=j;
		if (i!=r) swap(a[i],a[r]);
		if (!a[i][i]) {puts("No Solution");return;}
		int inv=Fpow(a[i][i],MOD-2);
		for (int k=1;k<=n;k++)
		{
			if (i==k) continue;
			int p=a[k][i]*inv%MOD;
			for (int j=i;j<=(n<<1);j++)
				a[k][j]=((a[k][j]-a[i][j]*p)%MOD+MOD)%MOD;	
		}
		for (int j=1;j<=(n<<1);j++)
			a[i][j]=(a[i][j]*inv)%MOD;
	}
	for (int i=1;i<=n;i++) {for (int j=n+1;j<=(n<<1);j++) printf("%lld ",a[i][j]); puts("");}
}
signed main()
{
	read(n);
	for (int i=1;i<=n;i++)
		for (int j=1;j<=n;j++) read(a[i][j]),a[i][i+n]=1;
	Calc();
	return 0;
}