// Approach adapted from POJ 1390 (reference problem).

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Capacity of every table below (valid indices 0..maxn-1).
const int maxn = 105;

// ---- Input ----
int n;          // number of scores read into a[1..n]
char s[maxn];   // the input string of digit characters
int a[maxn];    // a[x]: score read from input; the DP treats it as the gain
                // for erasing x equal characters with a single operation

// ---- Run-length encoding of s (1-based) ----
int c[maxn];    // c[b]: digit value (s[...] - '0') of the b-th maximal run
int len[maxn];  // len[b]: number of characters in the b-th run
int cnt;        // number of runs

// ---- DP tables ----
ll f[maxn][maxn];        // memo for calf(i, j); -1 means "not computed yet"
ll cost[maxn];           // cost[x]: best score for erasing x equal characters
ll dp[maxn][maxn][maxn]; // memo for caldp(i, j, k); -1 means "not computed yet"

// calf(i, j): best score for i equal characters split into j consecutive,
// non-empty parts. The first part is scored a[part] and every other part
// is scored cost[part].
// Memoized in f[i][j], where -1 means "not computed yet".
// Reads cost[1..i-j+1], so cost[] must already be final for every length
// below i. main() fills cost[] in increasing order of length.
ll calf(int i, int j) {
    ll &memo = f[i][j];
    if (memo != -1) return memo;
    // A single part: the whole group is scored by one a[] entry.
    if (j == 1) return a[i];
    // i parts: every part is one character.
    if (j == i) return 1LL * i * a[1];
    // Choose the length of the last part. The first j-1 parts share the
    // remaining i - last characters, at least one character each.
    const int longest = i - j + 1;
    for (int last = 1; last <= longest; ++last) {
        const ll candidate = calf(i - last, j - 1) + cost[last];
        if (candidate > memo) memo = candidate;
    }
    return memo;
}

// caldp(i, j, k): best score for erasing runs i..j when k extra characters
// equal to c[j] are glued onto run j.
// Memoized in dp[i][j][k], where -1 means "not computed yet".
ll caldp(int i, int j, int k) {
    ll &memo = dp[i][j][k];
    if (memo != -1) return memo;
    const int tail = len[j] + k;  // run j together with its extra characters
    if (i == j) return cost[tail];
    // Option 1: erase run j plus its extras on its own, then solve i..j-1.
    ll best = caldp(i, j - 1, 0) + cost[tail];
    // Option 2: clear runs t+1..j-1 first, so that run t (same character as
    // run j) joins the tail. The tail is then carried as extras of run t.
    // Because c[] comes from maximal runs, c[j-1] != c[j], so t <= j-2 and
    // the middle range t+1..j-1 is never empty.
    for (int t = i; t < j; ++t) {
        if (c[t] != c[j]) continue;
        const ll merged = caldp(i, t, tail) + caldp(t + 1, j - 1, 0);
        if (merged > best) best = merged;
    }
    memo = best;
    return memo;
}

int main() {
scanf("%d", &n);
scanf("%s", s);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
cnt = 1;
c[cnt] = s[0] - '0';
len[cnt] = 1;
for (int i = 1; s[i]; i++) {
if (s[i] == s[i-1]) {
len[cnt]++;
}
else {
cnt++;
c[cnt] = s[i] - '0';
len[cnt] = 1;
}
}
memset(f, -1, sizeof(f));
for (int i = 1; i <= n; i++) {
cost[i] = a[i];
for (int j = 1; j <= i; j++) {
cost[i] = max(cost[i], calf(i, j));
}
}
memset(dp, -1, sizeof(dp));
printf("%lld\n", caldp(1, cnt, 0));
return 0;
}