假如 n = 2 n=2 n=2,那么这是一个 k m p kmp kmp的裸题
把两个字符串 s 2 s_2 s2放在 s 1 s_1 s1前面拼接成一个串跑 k m p kmp kmp
那么从 n x t nxt nxt数组的 s 2 . l e n g t h ( ) s_2.length() s2.length()后找最大的最长公共前后缀就是答案
然而现在 n n n巨大,每次拼接,跑 k m p kmp kmp复杂度爆炸
朴素 k m p kmp kmp代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+10;
int n,top,kmp[maxn],len[maxn];
string s[maxn];
char a[maxn],temp[maxn];
void KMP(char a[] ,int n )
{
int j = 0;
for(int i=2;i<=n;i++)//kmp[1] = 0
{
while( j&&a[i]!=a[j+1] ) j = kmp[j];
if( a[j+1]==a[i] ) j++;
kmp[i] = j;
}
}
int main()
{
cin >> n;
for(int i=1;i<=n;i++) { cin >> s[i]; len[i] = s[i].length(); }
int last = 0;
for(int i=1;i<=len[1];i++) a[++last] = s[1][i-1];
for(int i=2;i<=n;i++)
{
for(int j=1;j<=len[i];j++) temp[j] = s[i][j-1];
for(int j=1;j<=last;j++) temp[j+len[i]] = a[j];//拼接字符串
KMP( temp,last+len[i] );
int mx = kmp[last+len[i]];
for(int j=mx+1;j<=len[i];j++) a[last+j-mx] = s[i][j-1];
last = last+len[i]-mx;
}
for(int i=1;i<=last;i++) printf("%c",a[i] );
}
但是,观察到每次合并答案串和第 i i i个字符串的目的是为了找到一个最大的 n x t [ i ] nxt[i] nxt[i]
那么这个 n x t [ i ] nxt[i] nxt[i]的长度不会大于 x = m i n ( 答 案 串 长 , 第 i 个 串 长 ) x=min(答案串长,第i个串长) x=min(答案串长,第i个串长)
所以每次只需要截取答案串的前 x x x跑 k m p kmp kmp即可
这样交上去就 W A WA WA了…
比如下面这个测试数据
2
ababa baba
合并之后对 b a b a b a b a babababa babababa求 k m p kmp kmp
结果…当 i = 8 i=8 i=8时, n x t [ i ] = 6.... nxt[i]=6.... nxt[i]=6....因为直接把两个串的公共部分合并进去了…
怎样来避免这个问题??我们其实只需要在拼接的中间加上一个特殊字符 w w w即可
考虑最长公共前后缀的前缀为 [ 1 , l e n ] [1,len] [1,len],后缀为 [ n − l e n + 1 , n ] [n-len+1,n] [n−len+1,n]
设 w w w出现在 i d id id位置,若 i d ∈ [ 1 , l e n ] id\in[1,len] id∈[1,len]
那么 i d = n − l e n + i d id=n-len+id id=n−len+id
但是 i d ! = n − l e n + i d id!=n-len+id id!=n−len+id,因为 l e n len len一定小于 n n n
所以不存在公共前后缀包含 w w w
即:在分界线处加入 w w w不会对答案产生影响
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6+10;
int n,top,kmp[maxn],len[maxn];
string s[maxn];
char a[maxn],temp[maxn];
void KMP(char a[] ,int n )
{
int j = 0;
for(int i=2;i<=n;i++)//kmp[1] = 0
{
while( j&&a[i]!=a[j+1] ) j = kmp[j];
if( a[j+1]==a[i] ) j++;
kmp[i] = j;
}
}
int main()
{
cin >> n;
for(int i=1;i<=n;i++) { cin >> s[i]; len[i] = s[i].length(); }
int last = 0;
for(int i=1;i<=len[1];i++) a[++last] = s[1][i-1];
for(int i=2;i<=n;i++)
{
for(int j=1;j<=len[i];j++) temp[j] = s[i][j-1];
temp[len[i]+1] = '.';//添加特殊符号
int x = min( last,len[i] );
for(int j=1;j<=x;j++) temp[1+j+len[i]] = a[last-x+j];//拼接字符串
KMP( temp,1+x+len[i] );
int mx = kmp[1+x+len[i]];
for(int j=mx+1;j<=len[i];j++) a[last+j-mx] = s[i][j-1];
last = last+len[i]-mx;
}
for(int i=1;i<=last;i++) printf("%c",a[i] );
}