题目:

帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的n \times mn×m的矩阵,矩阵中的每个元素ai,ja_{i,j}ai,j​

均为非负整数。游戏规则如下:

  1. 每次取数时须从每行各取走一个元素,共nnn个。经过mmm次后取完矩阵内所有元素;
  2. 每次取走的各个元素只能是该元素所在行的行首或行尾;
  3. 每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值×2i\times 2^i×2i ,其中iii表示第iii次取数(从1开始编号);
  4. 游戏结束总得分为mmm次取数得分之和。

帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。


思路:

因为每一行的游戏是互不影响的,所以我们可以一行一行来做。

对于每一行,设f[i][j]f[i][j]f[i][j]表示取完i∼ji\sim ji∼j的最大得分,那么显然有

f[i][j]=max(f[i−1][j]+a[i]×2n−(j−i),f[i][j−1]+a[j]×2n−(j−i))f[i][j]=max(f[i-1][j]+a[i]\times 2^{n-(j-i)},f[i][j-1]+a[j]\times 2^{n-(j-i)})f[i][j]=max(f[i−1][j]+a[i]×2n−(j−i),f[i][j−1]+a[j]×2n−(j−i))

高精度转移即可。


代码:

#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N=90,MAXN=10;
int n,m,ans[MAXN+1],f[N][N][MAXN+1],power[N][MAXN+1],a[N][MAXN+1],p[MAXN+1],q[MAXN+1];

void mul(int c[MAXN+1],int a[MAXN+1],int b[MAXN+1])
{
for (register int i=MAXN;i>=1;i--)
{
int t=0;
for (register int j=MAXN;j>=1;j--)
{
c[i+j-MAXN]+=a[i]*b[j]+t;
t=c[i+j-MAXN]/10000;
c[i+j-MAXN]%=10000;
}
}
}

void add(int a[MAXN+1],int b[MAXN+1])
{
int t=0;
for (register int i=MAXN;i>=1;i--)
{
a[i]+=b[i]+t;
t=a[i]/10000;
a[i]%=10000;
}
}

bool check(int a[MAXN+1],int b[MAXN+1])
{
for (register int i=1;i<=MAXN;i++)
if (a[i]>b[i]) return 1;
else if (a[i]<b[i]) return 0;
return 1;
}

int main()
{
// freopen("testdata.in","r",stdin);
scanf("%d%d",&m,&n);
power[0][MAXN]=1; power[1][MAXN]=2;
for (register int i=2;i<=n;i++)
mul(power[i],power[i-1],power[1]);
while (m--)
{
memset(a,0,sizeof(a));
memset(f,0,sizeof(f));
for (register int i=1,x;i<=n;i++)
{
scanf("%d",&x);
for (register int j=MAXN;j>=1;j--,x/=10000)
a[i][j]=x%10000;
mul(f[i][i],a[i],power[n]);
}
for (register int i=n;i>=1;i--)
for (register int j=i+1;j<=n;j++)
{
memset(p,0,sizeof(p));
memset(q,0,sizeof(q));
mul(p,a[i],power[n-(j-i)]);
mul(q,a[j],power[n-(j-i)]);
add(p,f[i+1][j]);
add(q,f[i][j-1]);
if (check(p,q)) memcpy(f[i][j],p,sizeof(p));
else memcpy(f[i][j],q,sizeof(q));
}
add(ans,f[1][n]);
}
int i=1;
while (!ans[i] && i<=MAXN) i++;
if (i>MAXN) return !printf("0");
printf("%d",ans[i]);
for (i++;i<=MAXN;i++)
{
if (ans[i]<1000) putchar(48);
if (ans[i]<100) putchar(48);
if (ans[i]<10) putchar(48);
printf("%d",ans[i]);
}
return 0;
}