计蒜客 2019ICPC 南昌网络赛 F Megumi With String(后缀自动机)_#define

计蒜客 2019ICPC 南昌网络赛 F Megumi With String(后缀自动机)_子串_02

计蒜客 2019ICPC 南昌网络赛 F Megumi With String(后缀自动机)_子串_03

 

 

大致题意:给你一个字符串S和一个多项式f(x),然后再告诉你一个长度。随机取这样一个长度的字符串T,如果T包含S的某个子串且这个子串的长度为len,那么T的权值就增加f(len),如果包含多个子串,那么权值为他们的和。现在问你这个T的期望权值是多少,且S不断变长,你需要求T在每个S情况下的期望。

这题可以参照 ​​HDU 6405​​ 两题有类似之处。

这里本质上相当于统计每个子串出现的期望,这样就可以去掉重复的子串。而每个子串的期望贡献只与长度有关,所以我们可以直接求每个长度的贡献。对于一个长度为len的子串,在长度为n的T字符串中,其产生的期望为:

                           

计蒜客 2019ICPC 南昌网络赛 F Megumi With String(后缀自动机)_#define_04

然后,我们考虑本质不同的子串长度是多少。我们考虑每次插入一个字符之后计算一次。具体来说,根据后缀自动机的定义,当插入一个字符之后,产生的所有新的子串的长度,是介于其parent节点长度到当前点长度之间的。那么,我们只需要把这一段连续长度的期望加上去即可。直接对刚刚求的各个长度的子串期望求前缀和即可。具体见代码:

#include<bits/stdc++.h>
#define INF 0x3f3f3f3f3f3f3f3fll
#define eps 1e-6
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

const int N = 1e5 + 10;
const int mod = 998244353;
const int inv26 = 729486258;

int tot,cur;
struct node{int ch[26],len,fa;} T[N<<3];
void init(int l){cur=tot=1;memset(T,0,sizeof(T[0])*(++l));}

inline int ins(int x,int id)
{
int p=cur;cur=++tot;T[cur].len=id;
for(;p&&!T[p].ch[x];p=T[p].fa) T[p].ch[x]=cur;
if (!p) {T[cur].fa=1;return cur;}int q=T[p].ch[x];
if (T[p].len+1==T[q].len) {T[cur].fa=q;return cur;}
int np=++tot; memcpy(T[np].ch,T[q].ch,sizeof(T[q].ch));
T[np].fa=T[q].fa; T[q].fa=T[cur].fa=np; T[np].len=T[p].len+1;
for(;p&&T[p].ch[x]==q;p=T[p].fa) T[p].ch[x]=np; return cur;
}

char s[N];
int a[N],sum[N<<1];

int main()
{
int TT; sc(TT);
while(TT--)
{

int l,k,m,n,ans=0;
scc(l,k); scc(n,m);
init((l+m)<<1);
scanf("%s",s+1);
for(int i=0;i<=k;i++) sc(a[i]);
sum[0]=a[0];
for(int i=1,inv=1;i<=min(n,l+m);i++)
{
int ss=0;
for(int j=0,pw=1;j<=k;j++,pw=(LL)pw*i%mod)
ss=(ss+(LL)a[j]*pw%mod)%mod;
inv=(LL)inv*inv26%mod;
sum[i]=(LL)ss*(n-i+1)%mod*inv%mod;
sum[i]=(sum[i]+sum[i-1])%mod;
if (i>l) continue;
int tmp=ins(s[i]-'a',i);
int r=min(T[tmp].len,n);
int l=min(T[T[tmp].fa].len,n);
ans=((ans+sum[r]-sum[l])%mod+mod)%mod;
}
printf("%d\n",ans);
while(m--)
{
scanf("%s",s);
int tmp=ins(s[0]-'a',++l);
int r=min(T[tmp].len,n);
int l=min(T[T[tmp].fa].len,n);
ans=((ans+sum[r]-sum[l])%mod+mod)%mod;
printf("%d\n",ans);
}
}
}