珠宝商

题目描述

Louis.PS 是一名精明的珠宝商,他出售的项链构造独特,很大程度上是因为他的制作方法与众不同。每次 Louis.PS 到达某个国家后,他会选择一条路径去遍历该国的城市。在到达一个城市后,他会使用在这个城市流行的材料制作一颗珠子,并按照城市被访问的顺序将珠子串联做成项链,为了使制作出来的项链不会因为城市之间的竞争而影响销量,路径中同一个城市不会重复出现(因为如果项链中 $A$ 城市的材料比 $B$ 城市的材料使用的多,则项链在 $B$ 城市的宣传可能会受到影响)。经过多年对消费者的调查, Louis.PS 已经掌握了判断一条项链吸引消费者程度的方法,具体来说, Louis.PS 经过调查得出了受消费者欢迎的项链的特征,并基于此制作了一个长项链(Louis.PS 称之为特征项链)。对于一条待售的项链,这条项链在特征项链里出现的次数越多,这条项链就越受消费者欢迎。

考虑到现实情况的复杂性,我们对条件做出适当的简化。 对于每个国家,在某些城市间存在道路直接相连,对于两个不同的城市,有且仅有一条路径连接这两个城市(即国家是连通的,且不存在一个环)。对于每个城市,我们用一个小写字母来表示在这个城市流行的材料类型。这样,我们就可以用一个仅包含小写字母的字符串来表示一条项链, 我们将特征项链所对应的字符串称作特征字符串,设为 $EigenString[1...M]$,$M$ 为特征项链的长度。对于一条项链,假设其对应字符串为 $P[1...L]$,$L$ 为这条项链的长度。如果存在一个正整数$K$, 使$EigenString[K...K+L-1]=P[1..L]$,称这条项链在特征项链中出现了一次。 满足上述 条件的正整数$K$的个数即为这条项链在特征项链的出现次数,记为$Popularity(P)$。

Louis.PS 使用数学中的期望概念来评价一个国家是否适合珠宝的采集,对于一个包含 $N$ 个城市的国家,令 $Str_{u,v}$ 表示沿着从 $u$ 开始,至 $v$ 结束的路径所得到的项链的对应字符串。 ($Str_{u,v}$ 与 $Str_{v,u}$ 表示的串一般不相同),则 $$Expectation=\sum_{u,v} Popularity(Str_{u,v}) / N^2$$ 对于如下的例子(图中实线表示两端点的国家有直接道路相连):

$N=3$,所流行的材料类型分别为 $\tt{a,a,b}$。 CTSC2010 珠宝商_后缀自动机、后缀树和后缀数组 $$Expectation=\dfrac{3+1+2+1+3+1+1+1+2}{9}=\dfrac{5}{3}$$ 对于一个国家, Louis.PS 想知道其 Expectation 的值,但苦于计算期望的工作量太大。作为珠宝店的学徒, 你当然不愿放过难得在老板面前展示自己的机会。

输入输出格式

输入格式:

输入文件$\tt{jewelry.in}$,第一行包含两个整数 $N$,$M$,表示城市个数及特征项链长度。 接下来的 $N-1$ 行, 每行两个整数 $x,y$, 表示城市 $x$ 与城市 $y$ 有直接道路相连。城市由 $1~N$ 进行编号。 接下来的一行,包含一个长度为 $N$,仅包含小写字母的字符串,第 $i$ 位的字符表示在城市 $i$ 流行的原料类型。 最后一行, 包含一个长度为 $M$, 仅包含小写字母的字符串, 表示特征字符串。

输出格式:

输出文件 $\tt{jewelry.out}$ 仅包含一个整数,为 $N^2 * Expectation$。

输入输出样例

输入样例#1: 复制
3 5
1 2
1 3
aab
abaab
输出样例#1: 复制
15

说明

有 $20\%$的数据,满足$M \leq 1000$;

有 $40\%$的数据,满足ܰ$N \leq 8000, M \leq 50000$;

对于 $100\%$的数据,$N,M \leq 50000$。

题解

参照张天扬《后缀自动机及其应用》和SFN1036的题解。

因为是求出现次数,显然可以用到后缀自动机来做。 首先考虑两种不同的暴力做法:

暴力1

枚举每个点作为起点,然后把整棵树dfs一次,求出起点到每个点组成的路径的出现次数。由于sam的转移是\(O(1)\)的,所以这么做总的复杂度是\(O(n^2)\)

暴力2

我们考虑求每个点作为路径的lca时候的贡献。设路径的lca为点Z,那么对于一条路径(X,Y),我们可以将其拆成(X,Z)和(Z,Y)两条路径。

考虑Z上的字符在S中的出现位置。那么(X,Z)这一段在S中出现位置的right一定是Z出现位置的right的子集。同理,把(Z,Y)这一段反过来,那么它在S的反串中出现位置的子集的right也必然是Z在反串中出现位置的子集。那么从Z开始dfs,每次维护当前串在原串和反串的后缀树上的位置。然后我们将两个后缀树从上向下递推一遍就可以求出每一个节点的匹配数量了。最后把后缀树上每一个后缀对应位置的匹配数相乘,就可以得到经过lca的路径在原串中的出现次数之和了。

但注意到有可能X和Y位于同一棵子树内,所以还要对每棵子树再求一次来去重。关于如何实现求每个位置的匹配路径数量,我们可以在正反两串的后缀树上打标记,最后下推到叶节点就好了。这么做因为求每个点贡献的时候都要把后缀树扫一遍,所以总的复杂度是\(O(n^2+nm)\)

再仔细思考一下不难发现暴力2可以通过点分治来优化。因为点分治后,所有分治子树的size和大小为\(O(n\log n)\),所以总的复杂度就是\(O(n \log n+nm)\)

整合

那么现在我们的瓶颈就在于每次扫后缀树时的\(O(m)\)

当前的分治子树size较小时,我们暴力扫后缀树显然是一种浪费。那么我们可以怎么做呢?暴力1!

我们不妨设一个阈值B,当分治子树的size不超过B时我们用做法1,不然就用做法2。 显然size超过B的分治子树只有不超过\(O(⌊n/B⌋)\)个,若我们碰到一个size不超过B的子树就退出的话,也可以证明遍历到的size不超过B的子树只有\(O(⌊n/B⌋)\)个。

解一下方程发现,当B是\(√m\)的时候时间复杂度取到最优。 这样总的复杂度就是\(O((n+m)√m)\),就可以AC啦。

如何在后缀树上打标记?

假设我们建出了S的后缀树。 当我们要加入字符串(Z,Y)的时候,因为后缀树本质是一棵所有后缀组成的trie,于是我们可以从后缀树的根开始往下跳。跳到最终点的时候,我们就在这里打一个标记,然后最后再把所有标记下传到叶节点,那么我们就可以知道每个后缀的前缀匹配了多少字符串。这样就知道正反串的后缀的匹配数,把对应的位置乘起来即可。

这题其实还有一个优化,就是在用方法2去重的时候,也应该按照子树的大小来决定用哪一种方法。而我比较懒,所以就直接用了方法2来去做。


参考bztMinamoto的代码。学习了在后缀树上跳状态。

co int N=1e5;
int n,m;
vector<int> e[N];
char str[N],buf[N]; // city string

int vis[N],maxs[N],siz[N],size,root,sqr; // point divide and conquer
ll ans;

struct Suffix_Automaton{
	int str[N]; // Eigen String
	int last,tot;
	Suffix_Automaton() {last=tot=1;}
	int ch[N][26],fa[N],len[N],pos[N],ref[N],siz[N]; // ref:out->in
	void extend(int c,int po){
		int p=last,cur=last=++tot;
		len[cur]=len[p]+1,pos[cur]=po,ref[po]=cur,siz[cur]=1;
		for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=cur;
		if(!p) fa[cur]=1;
		else{
			int q=ch[p][c];
			if(len[q]==len[p]+1) fa[cur]=q;
			else{
				int clone=++tot;
				memcpy(ch[clone],ch[q],sizeof ch[q]);
				fa[clone]=fa[q],len[clone]=len[p]+1,pos[clone]=pos[q];
				fa[cur]=fa[q]=clone;
				for(;ch[p][c]==q;p=fa[p]) ch[p][c]=clone;
			}
		}
	}
	int cnt[N],ord[N],son[N][26];
	void build(){
		for(int i=1;i<=tot;++i) ++cnt[len[i]];
		for(int i=1;i<=m;++i) cnt[i]+=cnt[i-1]; // edit 1: m
		for(int i=1;i<=tot;++i) ord[cnt[len[i]]--]=i;
		for(int i=tot,p;i;--i){
			p=ord[i];
			siz[fa[p]]+=siz[p];
			son[fa[p]][str[pos[p]-len[fa[p]]]]=p;
		}
	}
	int tag[N];
	void mark(int u,int fa,int now,int len){ // now,len is for fa
		if(!now) return;
		if(len==this->len[now]) now=son[now][::str[u]-'a'];
		else if(str[pos[now]-len]!=::str[u]-'a') now=0;
		if(!now) return;
		++tag[now];
		for(int i=0,v;i<e[u].size();++i){
			if(vis[v=e[u][i]]||v==fa) continue;
			mark(v,u,now,len+1);
		}
	}
	void push(){
		for(int i=1;i<=tot;++i) tag[ord[i]]+=tag[fa[ord[i]]];
	}
}sam1,sam2;


void find_root(int u,int fa){
	siz[u]=1,maxs[u]=0;
	for(int i=0,v;i<e[u].size();++i){
		if(vis[v=e[u][i]]||v==fa) continue;
		find_root(v,u),siz[u]+=siz[v],maxs[u]=max(maxs[u],siz[v]);
	}
	maxs[u]=max(maxs[u],size-siz[u]);
	if(maxs[u]<maxs[root]) root=u;
}
int num,tmp[N];
void get_shipped(int u,int fa){
	tmp[++num]=u;
	for(int i=0,v;i<e[u].size();++i){
		if(vis[v=e[u][i]]||v==fa) continue;
		get_shipped(v,u);
	}
}
void brute_force(int u,int fa,int now){
	now=sam1.ch[now][str[u]-'a'];
	if(!now) return;
	ans+=sam1.siz[now];
	for(int i=0,v;i<e[u].size();++i){
		if(vis[v=e[u][i]]||v==fa) continue;
		brute_force(v,u,now);
	}
}
void work(int u,int fa,int op){
	fill(sam1.tag+1,sam1.tag+sam1.tot+1,0);
	fill(sam2.tag+1,sam2.tag+sam2.tot+1,0);
	int to=str[fa]-'a';
	if(fa) sam1.mark(u,fa,sam1.son[1][to],1),sam2.mark(u,fa,sam2.son[1][to],1);
	else sam1.mark(u,fa,1,0),sam2.mark(u,fa,1,0);
	sam1.push(),sam2.push();
	for(int i=1;i<=m;++i) ans+=(ll)op*sam1.tag[sam1.ref[i]]*sam2.tag[sam2.ref[m-i+1]];
}
void solve(int u){
	if(size<=sqr){
		num=0,get_shipped(u,0);
		for(int i=1;i<=num;++i) brute_force(tmp[i],0,1);
		return; // edit 2
	}
	vis[u]=1,work(u,0,1);
	for(int i=0,v;i<e[u].size();++i){
		if(vis[v=e[u][i]]) continue;
		work(v,u,-1);
	}
	for(int i=0,v,all=size;i<e[u].size();++i){
		if(vis[v=e[u][i]]) continue;
		root=0,size=siz[v]<siz[u]?siz[v]:all-siz[u];
		find_root(v,u),solve(root);
	}
}

int main(){
	sqr=sqrt(read(n)),read(m);
	for(int i=1,u,v;i<n;++i){
		read(u),read(v);
		e[u].push_back(v),e[v].push_back(u);
	}
	scanf("%s",str+1),scanf("%s",buf+1);
	for(int i=1;i<=m;++i) sam1.str[i]=buf[i]-'a',sam1.extend(buf[i]-'a',i);
	reverse(buf+1,buf+m+1);
	for(int i=1;i<=m;++i) sam2.str[i]=buf[i]-'a',sam2.extend(buf[i]-'a',i);
	sam1.build(),sam2.build();
	root=0,maxs[0]=n,size=n;
	find_root(1,0),solve(root);
	printf("%lld\n",ans);
	return 0;
}