给定n个矩阵{A1,A2,…,An},其中Ai与Ai+1是可乘的,i=1,2,…,n-1。考察这n个矩阵的连乘积A1A2…An。由于矩阵乘法满足结合律,故计算矩阵的连乘积可以有许多不同的计算次序,这种计算次序可以用加括号的方式来确定。若一个矩阵连乘积的计算次序完全确定,则可以依此次序反复调用2个矩阵相乘的标准算法(有改进的方法,这里不考虑)计算出矩阵连乘积。若A是一个p×q矩阵,B是一个q×r矩阵,则计算其乘积C=AB的标准算法中,需要进行pqr次数乘。
矩阵连乘积的计算次序不同,计算量也不同,举例如下:
先考察3个矩阵{A1,A2,A3}连乘,设这三个矩阵的维数分别为10×100,100×5,5×50。若按((A1A2)A3)方式需要的数乘次数为10×100×5+10×5×50=7500,若按(A1(A2A3))方式需要的数乘次数为100×5×50+10×100×50=75000。
下面使用动态规划法找出矩阵连乘积的最优计算次序。
I: 设矩阵连乘积AiAi+1…Aj简记为A[i:j],设最优计算次序在Ak和Ak+1之间断开,则加括号方式为:((AiAi+1…Ak)(Ak+1…Aj))。则依照这个次序,先计算A[i:k]和A[K+1:j]然后再将计算结果相乘,计算量是:A[i:k]的计算量加上A[K+1:j]的计算量再加上它们相乘的计算量。问题的一个关键是:计算A[i:j]的最优次序所包含的两个子过程(计算A[i:k]和A[K+1:j])也是最优次序。
II: 设计算A[i:j]所需的最少数乘次数为m[i][j]。
- i=j时为单一矩阵,则m[i][i]=0;
- i<j时,设最优计算次序在Ak和Ak+1之间断开,则m[i][j]=m[i][k]+m[k+1][j]+pi-1pkpj,其中p表示数组的维数。k此时并未确定,需要从i到j-1遍历以寻找一个最小的m[i][j]。我们把这个最小的k放在s[i][j]。其中s[i][j]记录了断开的位置,即计算A[i:j]的加括号方式为:(A[i:s[i][j]])*(A[s[i][j]+1:j])
其动态规划递归公式如下:
其中m[i][j]就是Ai...Aj这 j-i+1 个矩阵连乘需要的最少的乘法次数。
1: /**
2: * 下面是矩阵连乘问题的动态规划算法
3: * 假设有6个矩阵:
4: * A1 A2 A3 A4 A5 A6
5: * 30*35 35*15 15*5 5*10 10*20 20*25 则matrixChain为
6: * {30, 35, 15, 5, 10, 20, 25} 结果为
7: * ((A1 * (A1 * A2)) * ((A4 * A5) * A6) )
8: */
9: public class MatrixMultiply {
10: // Traceback打印A[i:j]的加括号方式
11: public static void traceback(int[][] s, int i, int j) {
12: //s[i][j]记录了断开的位置,即计算A[i:j]的加括号方式为:
13: //(A[i:s[i][j]])*(A[s[i][j]+1:j])
14: if (i == j)
15: return;
16: traceback(s, i, s[i][j]);//递归打印A[i:s[i][j]]的加括号方式
17: traceback(s, s[i][j] + 1, j);//递归打印A[s[i][j]+1:j]的加括号方式
18: System.out.println("Multiply A(" + i + "," + s[i][j] + ")and A("
19: + (s[i][j] + 1) + "," + j+")");
20:
21: }
22:
23: public static void matrixChain(int[] p, int[][] m, int[][] s) {
24: int n = p.length - 1;
25: for (int i = 0; i <= n; i++)
26: m[i][i] = 0;
27: // 上三角矩阵
28: for (int r = 2; r <= n; r++)//r为连乘矩阵的个数
29: for (int i = 1; i <= n - r + 1; i++) {//i就是连续r个矩阵的第一个
30: int j = i + r - 1;//j就是连续r个矩阵的最后一个
31: m[i][j] = 99999999;//m[i + 1][j] + p[i - 1] * p[i] * p[j];
32: s[i][j] = i;
33: //求m[i][j],m[i][j]就是Ai...Aj这 j-i+1 个矩阵连乘需要的最少的乘法次数
34: for (int k = i; k < j; k++) {
35: //A[i]的维数为:p[i - 1] * p[k]
36: int t = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
37: if (t < m[i][j]) {//寻找最小值
38: m[i][j] = t;
39: s[i][j] = k;//记录划分标记
40: }
41: }
42: }
43:
44:
45: //n个矩阵连乘的最少相乘次数
46: System.out.println("给定的"+n+"个矩阵连乘的最少相乘次数:"+m[1][n]);
47: printM(m);
48: printS(s);
49: }
50:
51: private static void printM(int[][] m) {
52: // TODO Auto-generated method stub
53: for (int i = 1; i < 7; i++) {
54: for (int j = 1; j < 7; j++) {
55: System.out.print(m[i][j] + "\t");
56: }
57: System.out.println();
58: }
59: }
60:
61: private static void printS(int[][] s) {
62: // TODO Auto-generated method stub
63: for (int i = 1; i < 7; i++) {
64: for (int j = 1; j < 7; j++) {
65: System.out.print(s[i][j] + "\t");
66: }
67: System.out.println();
68: }
69: }
70:
71: public static void main(String[] args) throws Exception {
72: int[][] m = new int[7][7];
73: int[][] s = new int[7][7];
74: int[] p = new int[] { 30, 35, 15, 5, 10, 20, 25 };
75: matrixChain(p, m, s);
76: //((A1(A2A3))((A4A5)A6))
77: traceback(s, 1, 6);
78:
79: }
80: }