题目

[CTSC2010]珠宝商_SAM

#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
	char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=1e5+5;
struct E{int v,nxt;}e[maxn];
int head[maxn>>1],sum[maxn>>1],mx[maxn>>1],vis[maxn>>1];
int T,rt,n,m,B,num;LL ans;
char S[maxn>>1],a[maxn>>1];
struct SAM {
	int len[maxn],t[maxn>>1],fa[maxn],pos[maxn],nxt[maxn][26],g[maxn];
	int son[maxn][26],s[maxn>>1],tax[maxn>>1],A[maxn],sz[maxn];
	int lst,cnt;
	inline void ins(int c,int o) {
		int p=++cnt,f=lst;lst=p;
		len[p]=len[f]+1,sz[p]=1,pos[p]=o;t[o]=p;
		while(f&&!son[f][c]) son[f][c]=p,f=fa[f];
		if(!f) {fa[p]=1;return;}
		int x=son[f][c];
		if(len[f]+1==len[x]) {fa[p]=x;return;}
		int y=++cnt;
		len[y]=len[f]+1,fa[y]=fa[x],fa[x]=fa[p]=y;
		for(re int i=0;i<26;i++) son[y][i]=son[x][i];
		while(f&&son[f][c]==x) son[f][c]=y,f=fa[f];
	}
	inline void build() {
		lst=cnt=1;
		for(re int i=1;i<=m;i++) ins(s[i],i);
		for(re int i=1;i<=cnt;i++) tax[len[i]]++;
		for(re int i=1;i<=m;i++) tax[i]+=tax[i-1];
		for(re int i=1;i<=cnt;i++) A[tax[len[i]]--]=i;
		for(re int i=cnt;i;--i) {
			int x=A[i];
			sz[fa[x]]+=sz[x];
			if(!pos[fa[x]]) pos[fa[x]]=pos[x];
			nxt[fa[x]][s[pos[x]-len[fa[x]]]]=x;
		}
	}
	inline void clear() {
		for(re int i=1;i<=cnt;i++) g[i]=0;
	}
	inline void update() {
		for(re int i=1;i<=cnt;i++) g[A[i]]+=g[fa[A[i]]];
	}
	void match(int x,int fa,int now,int l) {
		if(l==len[now]) now=nxt[now][a[x]];
		else if(s[pos[now]-l]!=a[x]) now=0;
		if(!now) return;g[now]++;l++;
		for(re int i=head[x];i;i=e[i].nxt) {
			if(vis[e[i].v]||e[i].v==fa) continue;
			match(e[i].v,x,now,l);
		}
	}
}p[2];
inline void add(int x,int y) {
	e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void getroot(int x,int fa) {
	sum[x]=1;mx[x]=0;
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||e[i].v==fa) continue;
		getroot(e[i].v,x);sum[x]+=sum[e[i].v];
		mx[x]=max(mx[x],sum[e[i].v]);
	}
	mx[x]=max(mx[x],T-sum[x]);
	if(mx[x]<mx[rt]) rt=x;
}
void calc(int x,int now,int fa) {
	now=p[0].son[now][a[x]];
	if(!now) return;
	ans+=p[0].sz[now];
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||e[i].v==fa) continue;
		calc(e[i].v,now,x);
	}
}
void solve(int x,int fa) {
	calc(x,1,0);
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]||e[i].v==fa) continue;
		solve(e[i].v,x);
	}
}
void getdis(int x,int fa) {
	p[0].clear(),p[1].clear();
	p[0].match(x,0,1,0);p[0].update();
	p[1].match(x,0,1,0);p[1].update();
	for(re int i=1;i<=m;i++)
		ans+=1ll*p[0].g[p[0].t[i]]*p[1].g[p[1].t[m-i+1]];
}
void del(int x,int fa) {
	p[0].clear(),p[1].clear();
	p[0].match(x,0,p[0].nxt[1][a[fa]],1);
	p[1].match(x,0,p[1].nxt[1][a[fa]],1);
	p[0].update(),p[1].update();
	for(re int i=1;i<=m;i++)
		ans-=1ll*p[0].g[p[0].t[i]]*p[1].g[p[1].t[m-i+1]];
}
void dfs(int x) {
	if(sum[x]<=B) {solve(x,0);return;}
	getdis(x,0);vis[x]=1;
	for(re int i=head[x];i;i=e[i].nxt) 
		if(!vis[e[i].v]) del(e[i].v,x);
	for(re int i=head[x];i;i=e[i].nxt) {
		if(vis[e[i].v]) continue;
		T=sum[e[i].v];rt=0;getroot(e[i].v,0);dfs(rt);
	}
}
int main() {
	n=read(),m=read();B=std::ceil(std::sqrt(n));
	for(re int x,y,i=1;i<n;i++)
		x=read(),y=read(),add(x,y),add(y,x);
	scanf("%s",a+1),scanf("%s",S+1);
	for(re int i=1;i<=n;i++) a[i]-='a';
	for(re int i=1;i<=m;i++) S[i]-='a';
	for(re int i=1;i<=m;i++) p[0].s[i]=S[i],p[1].s[m-i+1]=S[i];
	p[0].build(),p[1].build();
	mx[0]=n+1,T=n,getroot(1,0);dfs(rt);
	std::cout<<ans;
	return 0;
}