算法一段时间不看以后就会变得很陌生,今天来重温一下之前写的矩阵连乘算法。
感谢孙海燕老师这学期对我的帮助,让我对枯燥而又复杂的算法学习增加了信心。
一个好的博客不光是给现在的自己看的,更是给大家和一段时间以后的自己看的。希望我这次能够写的简单易懂。
首先
A1 30*35
A2 35*15
A3 15*5
A4 5*10
A5 10*20
A6 20*25
这几个矩阵通过不同的结合方式相乘,最后的总共运算量是不同的,我们需要找到一种结合方式使得矩阵的连乘积最小。
这个算法可以使用枚举法一个个相乘,但是那样子复杂度太高。
我们采用动态规划的算法,动态规划的算法是自底向上的。我们先算1个矩阵相乘(答案为0)再算两个矩阵相乘,一直到n个矩阵相乘。
我们使用一个二维数组m[][],表示矩阵连乘的信息。例如m[1][3]表示第一个矩阵到第三个矩阵的连乘。
因为一个矩阵自己出现时连乘积为0,则m[1][1],m[2][2],m[3][3].....都为0。即数组对角线上的值为0。
因此有代码:
for(int i=1;i<=n;i++)
{
m[i][i]=0;//对角线上的值均为0
}
这里可以发现我并没有使用数组的0下标,便于我们对问题的理解。
说完了一个矩阵的情况,我们开始正式计算多个矩阵相乘。
1过来以后便是2,再从2到n,因此定义变量r表示相乘的矩阵个数,但是最多只有n个矩阵因此r<=n。
定义一个变量i表示起始位置(这里可能会有不理解)我们2个矩阵相乘时,可以是m[1][2]、可以是m[2][3]、m[3][4]。。。所以说我们需要给i定义一个范围,我们知道,i+r-1<=n(起始点加上个数减去1小于等于总共的个数);所以,i<=n-r+1;
ok,我们知道了连乘矩阵的个数和起始点,我们还差一个结束点。想必大家也都知道,知道起点和距离就能知道终点,因此设置变量j为矩阵连乘的结束点。则j=i+r-1;大家有没有发现一个问题,n>=i+r-1,j=i+r-1。因此j<=n;这也符合了我们运算时数组不会越界的规则不是吗,现在看起来一切都很好理解。
for(int r=2;r<=n;r++)//相乘的矩阵的个数
{
for(int i=1;i<=n-r+1;i++)//i+r-1最大为n,当r=2时为2个矩阵相乘,i从1到5.当r=6时,i只能为1
{
int j=i+r-1;
接下来就是核心代码了,想要把这一点说清楚有点不太容易。
我们默认从i处分开时取得最小值,例如m[1][3],我们默认最小值为先算m[2][3]然后再与第一个矩阵相乘,m[2][3]的值已经在r=2的时候算过了。但其实还有一种算法是先算m[1][2],再与第三个矩阵相乘。因此我么再加上一个循环,将分割点向右挪动。为什么说是向右挪动呢,例如A1*A2*A3,我们默认最小是A1*(A2*A3),但是还有一种为A1*A2*(A3)。让我们看看四个矩阵相乘的时候。A1*A2*A3*A4,我们默认的最小值为A1*(A2*A3*A4),但还可以是A1*A2*(A3*A4)和A1*A2*A3*(A4)。
好了洋洋洒洒说了这么多。直接上代码可能更有效。
for(int r=2;r<=n;r++)//相乘的矩阵的个数
{
for(int i=1;i<=n-r+1;i++)//i+r-1最大为n,当r=2时为2个矩阵相乘,i从1到5.当r=6时,i只能为1
{
int j=i+r-1;
m[i][j]=m[i+1][j]+p[i-1]*p[i]*p[j];//从第i处断开
s[i][j]=i;
for(int k=i+1;k<j;k++)
{
int t=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j];//从第i个点以后逐个断开
if(t<m[i][j])
{
m[i][j]=t;
s[i][j]=k;
}
}
}
}
System.out.println(m[1][6]);
}
我知道可能有人会看不懂
m[i][j]=m[i+1][j]+p[i-1]*p[i]*p[j];
int t=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j];
这两个式子其实是一样的,都是将矩阵分成两半相乘,最后合起来。
m[i][j]=m[i][i]+m[i+1][j]+p[i-1]*p[i]*p[j];//从第i处断开
int t=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j];//从第i个点以后逐个断开
这样子是不是好理解多了?
只是从第i个矩阵后面断开以后直接m[i][i]=0,省略掉不写了而已。
其实也很好理解,无非就是起始点x断开点x结束点的值,这正是两矩阵相乘的连乘积。
至于数组s是用来记录从哪一点断开时连乘积最小的。
接下来是整体的代码:
package com.sf.work;
import java.util.Scanner;
public class Jzlc {
public static void lc(int m[][],int n,int s[][],int p[])
{
for(int i=1;i<=n;i++)
{
m[i][i]=0;//对角线上的值均为0
}
for(int r=2;r<=n;r++)//相乘的矩阵的个数
{
for(int i=1;i<=n-r+1;i++)//i+r-1最大为n,当r=2时为2个矩阵相乘,i从1到5.当r=6时,i只能为1
{
int j=i+r-1;
m[i][j]=m[i+1][j]+p[i-1]*p[i]*p[j];//从第i处断开
s[i][j]=i;
for(int k=i+1;k<j;k++)
{
int t=m[i][k]+m[k+1][j]+p[i-1]*p[k]*p[j];//从第i个点以后逐个断开
if(t<m[i][j])
{
m[i][j]=t;
s[i][j]=k;
}
}
}
}
System.out.println(m[1][6]);
}
public static void traceback(int s[][],int i,int j)
{
if(i==j)System.out.print("A"+i);
else
{
System.out.print("(");
traceback(s,i,s[i][j]);
traceback(s,s[i][j]+1,j);
System.out.print(")");
}
}
public static void main(String[] args) {
int n;
Scanner rd=new Scanner(System.in);
System.out.println("请输入矩阵的个数:");
n=rd.nextInt();
int p[]=new int[n+1];
System.out.println("请输入矩阵具体信息:");
for(int i=0;i<n+1;i++)
{
p[i]=rd.nextInt();
}
int m[][]=new int[n+1][n+1];
int s[][]=new int[n+1][n+1];
System.out.println("最小值为:");lc(m,n,s,p);
System.out.println("组合方式为:");
traceback(s,1,n);
}
}
运行结果
我一直认为能够说的清的都是真的掌握了的,哪怕有一点的不明白都不会有理有据的讲出来。感谢你能看到这里,希望自己越来越努力。