/*
将连续的若干矩阵相乘,若从左向右依次相乘,那么效率比较低。因为矩阵乘法满足结合律,所以
可以改变矩阵相乘的次序。算法将第0个到第i个矩阵从k个位置断开,利用递归找到最优的结合次序
*/
#include<iostream>

using namespace std;

const int n=6;
const int Max=1000000;
//m[i][j]表示第i个矩阵和第j个矩阵相乘的乘法总次数
//s[i][j]表示第i个矩阵和第j个矩阵相乘应该从第s[i][j]个位置断开
int M(int a[][2],int m[][n],int s[][n],int i,int j)
{
int k,min;
int x;
if(i==j)
{
m[i][j]=0;
s[i][j]=0;
return 0;
}
else if(i<j)
{
min=Max;
for(k=i;k<j;k++)
{
x=M(a,m,s,i,k)+M(a,m,s,k+1,j)+a[i][0]*a[k+1][0]*a[j][1];
if(x<min)
{
min=x;
s[i][j]=k;
}
}
m[i][j]=min;
return min;
}
}

//显示结合次序
void PrintSequence(int s[][n],int i,int j)
{
int k;
if(i==j)
{
cout<<"A"<<i;
}
else if(i<j)
{
k=s[i][j];
cout<<"(";
PrintSequence(s,i,k);
PrintSequence(s,k+1,j);
cout<<")";
}
}

//矩阵结构体
typedef struct
{
int m,n;
int **array;
}Matrix;

//初始化,a[n][2]存放n个矩阵的行数和列数,p[n]存放这n个矩阵
void init(int a[n][2],Matrix p[n])
{
int i,j,k;
for(i=0;i<n;i++)
{
cout<<"第"<<i<<"个矩阵的行:";
cin>>p[i].m;
cout<<"第"<<i<<"个矩阵的列:";
cin>>p[i].n;
a[i][0]=p[i].m;
a[i][1]=p[i].n;
p[i].array=new int*[p[i].n];
for(j=0;j<p[i].m;j++)
{
p[i].array[j]=new int[p[i].n];
for(k=0;k<p[i].n;k++)
{
cout<<"第"<<j<<","<<k<<"个值:";
cin>>p[i].array[j][k];
}
}
}
}

void printMatrix(Matrix p[n])
{
int i,j,k;
for(i=0;i<n;i++)
{
cout<<"第"<<i<<"个矩阵的行是:"<<p[i].m<<endl;
cout<<"第"<<i<<"个矩阵的列是:"<<p[i].n<<endl;
for(j=0;j<p[i].m;j++)
{
for(k=0;k<p[i].n;k++)
{
cout<<p[i].array[j][k]<<" ";
}
cout<<endl;
}
}
}

void printMatrix(const Matrix &m)
{
int i,j;
cout<<"结果矩阵行是:"<<m.m<<endl;
cout<<"结果矩阵列是:"<<m.n<<endl;
for(i=0;i<m.m;i++)
{
for(j=0;j<m.n;j++)
{
cout<<m.array[i][j]<<" ";
}
cout<<endl;
}
}

//将a,b两个矩阵相乘放入矩阵c中
void Multiple(const Matrix& a,const Matrix& b,Matrix &c)
{
if(a.n!=b.m)
{
cerr<<"矩阵类型不匹配!"<<endl;
return;
}
int i,j,k;
int s;
c.m=a.m;
c.n=b.n;
c.array=new int*[b.n];
for(i=0;i<a.m;i++) c.array[i]=new int[b.n];
for(i=0;i<a.m;i++)
{
for(j=0;j<b.n;j++)
{
s=0;
for(k=0;k<a.n;k++)
{
s+=a.array[i][k]*b.array[k][j];
}
c.array[i][j]=s;
}
}
}

//递归求p[n]矩阵乘积
void Multiple(Matrix p[n],int s[n][n],int i,int j,Matrix &result)
{
int k;
if(i==j)
{
result=p[i];
return;
}
else if(i<j)
{
k=s[i][j];
Matrix L,R;
Multiple(p,s,i,k,L);
Multiple(p,s,k+1,j,R);
Multiple(L,R,result);
}
}

void main()
{
int a[n][2];
int m[n][n],s[n][n];
Matrix p[n];
init(a,p);
int i,j;
cout<<M(a,m,s,0,n-1)<<endl;
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
if(i>j) m[i][j]=s[i][j]=0;
cout<<m[i][j]<<" ";
}
cout<<endl;
}
cout<<"-------------"<<endl;
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
cout<<s[i][j]<<" ";
}
cout<<endl;
}
PrintSequence(s,0,n-1);
cout<<endl;
printMatrix(p);
Matrix result;
Multiple(p,s,0,n-1,result);
printMatrix(result);
}