问题相当于求:串 a a a的某一段和串 b b b的某一段的归并序列和。
如果只有两段子串 a , b a,b a,b,求他们的归并序列数量,怎么做??
定义 f [ i ] [ j ] [ 0 / 1 ] f[i][j][0/1] f[i][j][0/1]表示取出串 a a a最近一次是 a i a_i ai,取出的串 b b b最近一次是 b j b_j bj,且当前取出的是 a / b a/b a/b串的字母
转移方程非常好写。初始化 f [ 1 ] [ 0 ] [ 0 ] = f [ 0 ] [ 1 ] [ 1 ] = 1 f[1][0][0]=f[0][1][1]=1 f[1][0][0]=f[0][1][1]=1即可
转移方程
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
if( a[i]!=a[i-1] ) f[i][j][0] += f[i-1][j][0];
if( a[i]!=b[j] ) f[i][j][0] += f[i-1][j][1];
if( b[j]!=b[j-1] ) f[i][j][1] += f[i][j-1][1];
if( b[j]!=a[i] ) f[i][j][1] += f[i][j-1][0];
f[i][j][0] %= mod; f[i][j][1] %= mod;
}
然后现在问题升级,求任意段的归并序列和
那我们可以把这两段看成多起点 d p dp dp
之前是只能从串 a a a的一号位置,串 b b b的一号位置开始
现在可以从任意位置,那是不是说初始化所有的
f [ i ] [ j ] [ 0 ] = f [ i ] [ j ] [ 1 ] f[i][j][0]=f[i][j][1] f[i][j][0]=f[i][j][1]当且仅当 a i ! = b j a_i!=b_j ai!=bj
然后我就在这里卡了一个小时
这样算答案会变小,因为我们方程的定义是取出串 a a a最近的是 a i a_i ai,取出串 b b b最近的是 b j b_j bj
虽然现在 a i ! = b j a_i!=b_j ai!=bj,但是我可以取出 a i , a i + 1 a_i,a_{i+1} ai,ai+1再取出 b j b_j bj,此时 a i + 1 ! = b j a_{i+1}!=b_j ai+1!=bj
这样也是符合这种定义的,然而我们根本没有算进去!!!
所以我们改变 f f f数组的初始化
定义 f [ i ] [ j ] [ 0 / 1 ] f[i][j][0/1] f[i][j][0/1]表示取出的归并序列在 a i a_i ai之前结束,而且在 b j b_j bj之前结束,最后取出的是哪个
设想一下,我们每次都是从串 a a a拿出一段 [ i , i + k ] [i,i+k] [i,i+k],再去拿 b j b_j bj
或者先从串 b b b拿出一段 [ j , j + k ] [j,j+k] [j,j+k]再去拿串 i i i
那么定义 s u m a i suma_i sumai表示 [ i − s u m a i + 1 , i ] [i-suma_i+1,i] [i−sumai+1,i]的字母是互不相同的
当 a i ! = b j a_i!=b_j ai!=bj时,有初始化 f [ i ] [ j ] [ 1 ] = s u m a i f[i][j][1]=suma_i f[i][j][1]=sumai
表示以 a i a_i ai结尾可以选择 s u m a i suma_i sumai个后缀,最后接上 b j b_j bj
对于 s u m b [ ] sumb[] sumb[]同样处理
然后就上前面的普通 d p dp dp即可
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5+10;
const int mod = 998244353;
int f[1009][1009][2];
char a[maxn],b[maxn];
int suma[maxn],sumb[maxn];//i之前连续多少个字符不同
signed main()
{
cin >> ( a+1 ) >> ( b+1 );
int n = strlen( a+1 ), m = strlen( b+1 );
suma[1] = sumb[1] = 1;
for(int i=2;i<=n;i++)
{
if( a[i]!=a[i-1] ) suma[i] = suma[i-1]+1;
else suma[i] = 1;
}
for(int i=2;i<=m;i++)
{
if( b[i]!=b[i-1] ) sumb[i] = sumb[i-1]+1;
else sumb[i] = 1;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
if( a[i]!=b[j] )
f[i][j][1] = suma[i], f[i][j][0] = sumb[j];
}
int ans = 0;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
if( a[i]!=a[i-1] ) f[i][j][0] += f[i-1][j][0];
if( a[i]!=b[j] ) f[i][j][0] += f[i-1][j][1];
if( b[j]!=b[j-1] ) f[i][j][1] += f[i][j-1][1];
if( b[j]!=a[i] ) f[i][j][1] += f[i][j-1][0];
f[i][j][0] %= mod; f[i][j][1] %= mod;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
ans = ( ans+f[i][j][0]+f[i][j][1] )%mod;
cout << ans;
}
然后还有更简单的一种方法
f [ i ] [ j ] [ 0 / 1 ] [ 0 / 1 ] [ 0 / 1 ] f[i][j][0/1][0/1][0/1] f[i][j][0/1][0/1][0/1]当前取到 a a a串的 i i i位置(不一定取了 a i a_i ai),取到 b b b串的 j j j位置(不一定取了 b j b_j bj)
此时最后一个取的字母是 a i a_i ai还是 b j b_j bj,当前状态是否取过 a a a串中的字母,当前状态是否取过 b b b中的字母
这样就可以傻瓜式转移了
因为初始化很方便,直接 f [ i ] [ j ] [ 0 ] [ 1 ] [ 0 ] = f [ i ] [ j ] [ 1 ] [ 0 ] [ 1 ] = 1 f[i][j][0][1][0]=f[i][j][1][0][1]=1 f[i][j][0][1][0]=f[i][j][1][0][1]=1即可
不会混淆,暴力转移就好
#include <bits/stdc++.h>
#define forn(i, n) for (int i = 0; i < int(n); i++)
using namespace std;
const int MOD = 998244353;
int add(int a, int b){
a += b;
if (a >= MOD)
a -= MOD;
if (a < 0)
a += MOD;
return a;
}
int main() {
string s, t;
cin >> s >> t;
int n = s.size(), m = t.size();
vector<vector<vector<vector<int>>>> dp(n + 1, vector<vector<vector<int>>>(m + 1, vector<vector<int>>(2, vector<int>(4, 0))));
int ans = 0;
forn(i, n + 1) forn(j, m + 1)
{
if (i < n) dp[i + 1][j][0][1] = add(dp[i + 1][j][0][1], 1);
if (j < m) dp[i][j + 1][1][2] = add(dp[i][j + 1][1][2], 1);
forn(mask, 4){
if (0 < i && i < n && s[i - 1] != s[i]) dp[i + 1][j][0][mask | 1] = add(dp[i + 1][j][0][mask | 1], dp[i][j][0][mask]);
if (0 < j && i < n && t[j - 1] != s[i]) dp[i + 1][j][0][mask | 1] = add(dp[i + 1][j][0][mask | 1], dp[i][j][1][mask]);
if (0 < i && j < m && s[i - 1] != t[j]) dp[i][j + 1][1][mask | 2] = add(dp[i][j + 1][1][mask | 2], dp[i][j][0][mask]);
if (0 < j && j < m && t[j - 1] != t[j]) dp[i][j + 1][1][mask | 2] = add(dp[i][j + 1][1][mask | 2], dp[i][j][1][mask]);
}
ans = add(ans, dp[i][j][0][3]);
ans = add(ans, dp[i][j][1][3]);
}
printf("%d\n", ans);
return 0;
}