LINK

问题相当于求:串 a a a的某一段和串 b b b的某一段的归并序列和。


如果只有两段子串 a , b a,b a,b,求他们的归并序列数量,怎么做??

定义 f [ i ] [ j ] [ 0 / 1 ] f[i][j][0/1] f[i][j][0/1]表示取出串 a a a最近一次是 a i a_i ai,取出的串 b b b最近一次是 b j b_j bj,且当前取出的是 a / b a/b a/b串的字母

转移方程非常好写。初始化 f [ 1 ] [ 0 ] [ 0 ] = f [ 0 ] [ 1 ] [ 1 ] = 1 f[1][0][0]=f[0][1][1]=1 f[1][0][0]=f[0][1][1]=1即可

转移方程

	for(int i=1;i<=n;i++)
	for(int j=1;j<=m;j++)
	{
		if( a[i]!=a[i-1] )	f[i][j][0] += f[i-1][j][0];
		if( a[i]!=b[j] )	f[i][j][0] += f[i-1][j][1];
		if( b[j]!=b[j-1] )	f[i][j][1] += f[i][j-1][1];
		if( b[j]!=a[i] )	f[i][j][1] += f[i][j-1][0];
		f[i][j][0] %= mod;	f[i][j][1] %= mod;
	}

然后现在问题升级,求任意段的归并序列和

那我们可以把这两段看成多起点 d p dp dp

之前是只能从串 a a a的一号位置,串 b b b的一号位置开始

现在可以从任意位置,那是不是说初始化所有的

f [ i ] [ j ] [ 0 ] = f [ i ] [ j ] [ 1 ] f[i][j][0]=f[i][j][1] f[i][j][0]=f[i][j][1]当且仅当 a i ! = b j a_i!=b_j ai!=bj

然后我就在这里卡了一个小时

这样算答案会变小,因为我们方程的定义是取出串 a a a最近的是 a i a_i ai,取出串 b b b最近的是 b j b_j bj

虽然现在 a i ! = b j a_i!=b_j ai!=bj,但是我可以取出 a i , a i + 1 a_i,a_{i+1} ai,ai+1再取出 b j b_j bj,此时 a i + 1 ! = b j a_{i+1}!=b_j ai+1!=bj

这样也是符合这种定义的,然而我们根本没有算进去!!!

所以我们改变 f f f数组的初始化

定义 f [ i ] [ j ] [ 0 / 1 ] f[i][j][0/1] f[i][j][0/1]表示取出的归并序列在 a i a_i ai之前结束,而且在 b j b_j bj之前结束,最后取出的是哪个

设想一下,我们每次都是从串 a a a拿出一段 [ i , i + k ] [i,i+k] [i,i+k],再去拿 b j b_j bj

或者先从串 b b b拿出一段 [ j , j + k ] [j,j+k] [j,j+k]再去拿串 i i i

那么定义 s u m a i suma_i sumai表示 [ i − s u m a i + 1 , i ] [i-suma_i+1,i] [isumai+1,i]的字母是互不相同的

a i ! = b j a_i!=b_j ai!=bj时,有初始化 f [ i ] [ j ] [ 1 ] = s u m a i f[i][j][1]=suma_i f[i][j][1]=sumai

表示以 a i a_i ai结尾可以选择 s u m a i suma_i sumai个后缀,最后接上 b j b_j bj

对于 s u m b [ ] sumb[] sumb[]同样处理

然后就上前面的普通 d p dp dp即可

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 2e5+10;
const int mod = 998244353;
int f[1009][1009][2];
char a[maxn],b[maxn];
int suma[maxn],sumb[maxn];//i之前连续多少个字符不同 
signed main()
{
	cin >> (  a+1 ) >> ( b+1 );
	int n = strlen( a+1 ), m = strlen( b+1 );
	suma[1] = sumb[1] = 1;
	for(int i=2;i<=n;i++)
	{
		if( a[i]!=a[i-1] )	suma[i] = suma[i-1]+1;
		else	suma[i] = 1;
	}
	for(int i=2;i<=m;i++)
	{
		if( b[i]!=b[i-1] )	sumb[i] = sumb[i-1]+1;
		else	sumb[i] = 1;		
	}
	for(int i=1;i<=n;i++)
	for(int j=1;j<=m;j++)
	{
		if( a[i]!=b[j] )	
			f[i][j][1] = suma[i], f[i][j][0] = sumb[j];
	}
	int ans = 0;
	for(int i=1;i<=n;i++)
	for(int j=1;j<=m;j++)
	{
		if( a[i]!=a[i-1] )	f[i][j][0] += f[i-1][j][0];
		if( a[i]!=b[j] )	f[i][j][0] += f[i-1][j][1];
		if( b[j]!=b[j-1] )	f[i][j][1] += f[i][j-1][1];
		if( b[j]!=a[i] )	f[i][j][1] += f[i][j-1][0];
		f[i][j][0] %= mod;	f[i][j][1] %= mod;
	}
	for(int i=1;i<=n;i++)
	for(int j=1;j<=m;j++)
		ans = ( ans+f[i][j][0]+f[i][j][1] )%mod;
	cout << ans;
} 

然后还有更简单的一种方法

f [ i ] [ j ] [ 0 / 1 ] [ 0 / 1 ] [ 0 / 1 ] f[i][j][0/1][0/1][0/1] f[i][j][0/1][0/1][0/1]当前取到 a a a串的 i i i位置(不一定取了 a i a_i ai),取到 b b b串的 j j j位置(不一定取了 b j b_j bj)

此时最后一个取的字母是 a i a_i ai还是 b j b_j bj,当前状态是否取过 a a a串中的字母,当前状态是否取过 b b b中的字母

这样就可以傻瓜式转移了

因为初始化很方便,直接 f [ i ] [ j ] [ 0 ] [ 1 ] [ 0 ] = f [ i ] [ j ] [ 1 ] [ 0 ] [ 1 ] = 1 f[i][j][0][1][0]=f[i][j][1][0][1]=1 f[i][j][0][1][0]=f[i][j][1][0][1]=1即可

不会混淆,暴力转移就好

#include <bits/stdc++.h>

#define forn(i, n) for (int i = 0; i < int(n); i++)

using namespace std;

const int MOD = 998244353;

int add(int a, int b){
	a += b;
	if (a >= MOD)
		a -= MOD;
	if (a < 0)
		a += MOD;
	return a;
}

int main() {
	string s, t;
	cin >> s >> t;
	int n = s.size(), m = t.size();
	vector<vector<vector<vector<int>>>> dp(n + 1, vector<vector<vector<int>>>(m + 1, vector<vector<int>>(2, vector<int>(4, 0))));
	int ans = 0;
	forn(i, n + 1) forn(j, m + 1)
	{
		if (i < n) dp[i + 1][j][0][1] = add(dp[i + 1][j][0][1], 1);
		if (j < m) dp[i][j + 1][1][2] = add(dp[i][j + 1][1][2], 1);
		forn(mask, 4){
			if (0 < i && i < n && s[i - 1] != s[i]) dp[i + 1][j][0][mask | 1] = add(dp[i + 1][j][0][mask | 1], dp[i][j][0][mask]);
			if (0 < j && i < n && t[j - 1] != s[i]) dp[i + 1][j][0][mask | 1] = add(dp[i + 1][j][0][mask | 1], dp[i][j][1][mask]);
			if (0 < i && j < m && s[i - 1] != t[j]) dp[i][j + 1][1][mask | 2] = add(dp[i][j + 1][1][mask | 2], dp[i][j][0][mask]);
			if (0 < j && j < m && t[j - 1] != t[j]) dp[i][j + 1][1][mask | 2] = add(dp[i][j + 1][1][mask | 2], dp[i][j][1][mask]);
		}
		ans = add(ans, dp[i][j][0][3]);
		ans = add(ans, dp[i][j][1][3]);
	}
	printf("%d\n", ans);
	return 0;
}