矩阵连乘优化
前言
从旭东的博客 看到一篇博文:矩阵连乘最优结合 动态规划求解,挺有意思的,这里做个转载【略改动】。
问题
矩阵乘法是这样的,比如\[ A_{ab} B_{bc} = C_{ac} \]
两个矩阵,一个a行,一个c列,行列乘法次数为a*c。一行乘以一列得到C中的一个元素,乘法次数为b,故矩阵乘法AB需要的乘法次数是a*c*b。
我们把b称为接口,那么矩阵连乘的次数是乘积的尺寸乘以中间的接口。中间的接口是矩阵高度,如果尽快能把长得高的矩阵通过乘法消化掉,这些大接口发生作用的机会就少,最终乘法次数就少了。
采用一维数组存储各矩阵高度。每次遍历找到最大值,和左边的矩阵相乘即可,直到最后只剩下一个矩阵。
输入参数是数组arr:存储各矩阵高度,最后一个元素为最后一个矩阵的列数,这个数组包含矩阵连乘表达式的所有信息。
loopTimes 是循环次数,循环次数为矩阵个数减1。
arrMaxId函数用于获取数组最大值索引,跳过第一个矩阵,因为第一个矩阵左边没有其他矩阵作为乘数。
由于最后要输出计算式,我们为每个矩阵设置一个名称,这个名称随着乘法的进行发生变化。最终会留下第一个矩阵,其名称就是最终运算式。
代码如下
int matrixMulTimes(vector<int> &arr) {
int maxId;
int mulTimes = 0;
int pre = 0, next=0;
string str="";
vector<string> matrixName(arr.size() - 1);
string mulLeftStr, mulRightStr;
int loops = 0;
int loopTimes = arr.size() - 2;
while (loops++ < loopTimes) {
maxId = arrMaxId(arr, 1, arr.size() - 2);
pre = maxId - 1;
next = maxId + 1;
while (arr[pre--] == -1);
while (arr[next++] == -1);
mulTimes += arr[++pre] * arr[maxId] * arr[--next];
arr[maxId] = -1;
mulLeftStr = matrixName[pre] == "" ? string(1, 'A' + pre) : matrixName[pre];
mulRightStr = matrixName[maxId] == "" ? string(1, 'A' + maxId) : matrixName[maxId];
matrixName[pre] = "(" + mulLeftStr +"*" + mulRightStr + ")";
}
cout << matrixName[0] << endl;
return mulTimes;
}
函数arrMaxId代码如下
int arrMaxId(vector<int> &arr, int begin, int end) {
if (arr.size() == 0 || arr.size()<end) {
return -1;
}
int maxId = begin;
for (int i = begin+1; i <= end; ++i) {
if (arr[i] > arr[maxId]) {
maxId = i;
}
}
return maxId;
}
上面只能是一个近似最优解,因为每次消去最高的矩阵,可能参与乘法的另一个矩阵也比较高,导致其存活更久更多地参与到运算中去,最后得不偿失。
如果非要得到最优解,可以运算并存储所有子式的运算量,自底向上,直到算出整个乘法算式。这个叫做动态规划,不过复杂度比较高。如果设置1000个矩阵相乘,动态规划可能算十几分钟都不一定有结果,但第一个算法几秒钟就能给出答案。
动态规划细节
每一个子式均由起点和跨度唯一决定。对于n个矩阵相乘,起点至多有n个,最小为0,最大为n-1。跨度至多n种,最小为0,最大为n-1。
怎么用之前的子式得到后面的子式呢?这个还挺麻烦的,得进行遍历,遍历子式中所有肯能的分割点。用两个n阶方阵分别存储每个子式的最少计算次数及分割点。
首先确定跨度,再确定起点,构成一个二级嵌套循环,通过起点的平移确定子式。子式确定后,再在内部嵌套一级循环,遍历子式的可能的分割点,并保存子式的最少计算次数及对应分割点。最后我们得到最大的子式,也就是连乘本身的最少运算次数及分割方案。
代码如下:
//根据记录的分割点,生成最后的矩阵相乘表达式
string make_result(vector<vector<int> > &points, int t1, int t2) {
if (t1 == t2)
return string(1, 'A' + t1 - 1);
int split = points[t1][t2];
return "(" + make_result(points, t1, split) + "*" + make_result(points, split + 1, t2) + ")";
}
int calculate_M(vector<int> &arr) {
int matrixNum = arr.size() - 1;
vector<vector<int> > num(matrixNum + 1, vector<int>(matrixNum + 1));
vector<vector<int> > points(matrixNum + 1, vector<int>(matrixNum + 1));
int span;
int start;
int end;
int spiltPoint;
int mulTimes;
int rows, columns, interfaces;
for (span = 1; span < arr.size() - 1; span++) {
for (start = 1; start + span < arr.size(); start++) {
end = start + span;
num[start][end] = INT_MAX;
for (spiltPoint = start; spiltPoint < end; spiltPoint++) {
rows = arr[start - 1];
columns = arr[end];
interfaces = arr[spiltPoint];
mulTimes = num[start][spiltPoint] + num[spiltPoint + 1][end] + rows * interfaces * columns;
if (mulTimes < num[start][end]) {
points[start][end] = spiltPoint;
num[start][end] = mulTimes;
}
}
}
}
cout << make_result(points, 1, matrixNum) << "\t最少乘法次数为:" << num[1][matrixNum] << endl;
return 0;
}
代码中用到的一些知识
C++提供模版类string,其中一个构造方法可将字符转化为字符串。如 string(1, 'A'+1),第一个参数是源字符延拓次数,这个构造函数将‘B’转化为"B"。