以每个叶子节点为开头进行dfs遍历,将遍历到的串全部加入建立广义SAM,结果即为本质不同的字符串个数
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
#define int long long
using namespace std;
const int maxn=2e6+10;
int n,m,tot,len[maxn],fa[maxn],ch[maxn][26],c;
int head[maxn<<1],nex[maxn<<1],to[maxn<<1];
int cnt[maxn];
char s[maxn];
int vis[maxn];
int res;
int col[maxn];
int k=1;
int ecnt;
void add(int x,int y) {
to[++ecnt]=y;
nex[ecnt]=head[x];
head[x]=ecnt;
}
int Ins(int c,int last) {
int p=last;
if(ch[p][c]) {
int q=ch[p][c];
if(len[p]+1==len[q]) {
return q;
} else {
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
return nq;
}
}
int np=++tot;
len[np]=len[p]+1;
for(; p&&!ch[p][c]; p=fa[p])ch[p][c]=np;
if(!p)fa[np]=1;
else {
int q=ch[p][c];
if(len[p]+1==len[q])fa[np]=q;
else {
int nq=++tot;
vis[nq]=k;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
}
}
return np;
}
int id[maxn];
inline void prepare() {
for (int i=1; i<=tot; i++) cnt[len[i]]++;
for (int i=1; i<=tot; i++) cnt[i]+=cnt[i-1];
for (int i=1; i<=tot; i++) id[cnt[len[i]]--]=i;
for (int i=tot; i>=1; i--) {
int X=id[i];
if (vis[fa[X]]==0) vis[fa[X]]=vis[X];
else if (vis[fa[X]]!=vis[X]) vis[fa[X]]=-1;
}
}
void dfs(int x,int fa,int last) {
int tmp=Ins(col[x],last);
for(int i=head[x]; i; i=nex[i]) {
int y=to[i];
if(y==fa)continue;
dfs(y,x,tmp);
}
}
int in[maxn];
signed main() {
std::ios::sync_with_stdio(false);
int n,c;
cin>>n>>c;
tot=1;
for(int i=1; i<=n; i++) {
cin>>col[i];
}
for(int i=1; i<n; i++) {
int x,y;
cin>>x>>y;
in[x]++;
in[y]++;
add(x,y);
add(y,x);
}
for(int i=1; i<=n; i++) {
if(in[i]==1)
dfs(i,i,1);
}
for(int i=2; i<=tot; i++) {
res+=len[i]-len[fa[i]];
}
cout<<res<<endl;
}
SP8093 JZPGYZ - Sevenk Love Oimaster
给定 n 个模板串,以及 m 个查询串
依次查询每一个查询串是多少个模板串的子串
先把所有串建立广义SAM,然后让每个串在SAM上跑,记录每个节点经历多少个串,若某串经历某点、那其parent树上所有点点也应当计算贡献、最后用询问串在SAM上跑即可
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#define int long long
using namespace std;
const int N=2e6+10;
int n,m,q,cnt[N],tot,last,len[N],fa[N],ch[N][26],ans;
string ss[N];
int col[N];
int vis[N];
string s;
int Ins(int c,int last) {
int p=last;
if(ch[p][c]) {
int q=ch[p][c];
if(len[p]+1==len[q]) {
return q;
} else {
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
return nq;
}
}
int np=++tot;
len[np]=len[p]+1;
for(; p&&!ch[p][c]; p=fa[p])ch[p][c]=np;
if(!p)fa[np]=1;
else {
int q=ch[p][c];
if(len[p]+1==len[q])fa[np]=q;
else {
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
}
}
return np;
}
signed main() {
cin>>m>>q;
tot=1;
for(int i=1; i<=m; i++) {
last=1;
cin>>ss[i];
n=ss[i].length();
for(int j=0; j<n; j++)
last=Ins(ss[i][j]-'a',last);
}
for(int i=1; i<=m; i++) {
int pos=1;
int len=ss[i].length();
for(int j=0; j<len; j++) {
pos=ch[pos][ss[i][j]-'a'];
int p=pos;
for(; p&&col[p]!=i; p=fa[p]) {
vis[p]++;
col[p]=i;
}
}
}
while(q--) {
cin>>s;
int len=s.length();
int pos=1;
int k;
for(k=0; k<len; k++) {
if(ch[pos][s[k]-'a']) {
pos=ch[pos][s[k]-'a'];
} else break;
}
if(k==len)cout<<vis[pos]<<endl;
else cout<<0<<endl;
}
}
P4081 [USACO17DEC]Standing Out from the Herd P
定义一个字符串的「独特值」为只属于该字符串的本质不同的非空子串的个数,求每个字符串的独特值。
建立广义SAM,在插入时,若发现该节点已有,则赋值该节点为-1,结束后子节点转移父节点、输出结果即可
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#define int long long
using namespace std;
const int maxn=1e6+10;
int n,m,tot,last,len[maxn],fa[maxn],ch[maxn][26];
int cnt[maxn];
char s[maxn];
int vis[maxn];
int res[maxn];
int k=1;
int Ins(int c,int last) {
int p=last;
if(ch[p][c]) {
int q=ch[p][c];
if(len[p]+1==len[q]) {
vis[q]=-1;
return q;
} else {
int nq=++tot;
vis[nq]=-1;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
return nq;
}
}
int np=++tot;
len[np]=len[p]+1;
vis[np]=k;
for(; p&&!ch[p][c]; p=fa[p])ch[p][c]=np;
if(!p)fa[np]=1;
else {
int q=ch[p][c];
if(len[p]+1==len[q])fa[np]=q;
else {
int nq=++tot;
vis[nq]=k;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
}
}
return np;
}
int id[maxn];
inline void prepare() {
for (int i=1; i<=tot; i++) cnt[len[i]]++;
for (int i=1; i<=tot; i++) cnt[i]+=cnt[i-1];
for (int i=1; i<=tot; i++) id[cnt[len[i]]--]=i;
for (int i=tot; i>=1; i--) {
int X=id[i];
if (vis[fa[X]]==0) vis[fa[X]]=vis[X];
else if (vis[fa[X]]!=vis[X]) vis[fa[X]]=-1;
}
}
signed main() {
// freopen("P4081_7.in","r",stdin);
scanf("%lld",&m);
tot=1;
for(int i=1; i<=m; i++) {
last=1;
scanf("%s",s);
n=strlen(s);
for(int j=0; j<n; j++)
last=Ins(s[j]-'a',last);
k++;
}
prepare();
for(int i=2; i<=tot; i++) {
res[vis[i]]+=len[i]-len[fa[i]];
}
for(int i=1; i<=m; i++) {
printf("%lld\n",res[i]);
}
}
CF204E Little Elephant and Strings
对于每个字符串a[i]寻找有序对(l,r)
即的子串[l...r]是字符串数组a中至少k个字符串的子串
考虑建立广义SAM、原串在SAM上跑记录结点经历多少串
处理出每个结点的贡献
最后再对每个串再跑一遍计算贡献
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#define int long long
using namespace std;
const int maxn=2e5+10;
int n,m,tot,last,len[maxn],fa[maxn],ch[maxn][26];
string s[maxn];
int vis[maxn<<1];
int col[maxn];
int Ins(int c,int last) {
int p=last;
if(ch[p][c]) {
int q=ch[p][c];
if(len[p]+1==len[q]) {
return q;
} else {
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
return nq;
}
}
int np=++tot;
len[np]=len[p]+1;
for(; p&&!ch[p][c]; p=fa[p])ch[p][c]=np;
if(!p)fa[np]=1;
else {
int q=ch[p][c];
if(len[p]+1==len[q])fa[np]=q;
else {
int nq=++tot;
len[nq]=len[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for(; p&&ch[p][c]==q; p=fa[p])ch[p][c]=nq;
}
}
return np;
}
int ln;
int id[maxn];
int cnt[maxn];
int sum[maxn];
int maxlen;
signed main() {
cin>>n>>m;
tot=1;
for(int i=1; i<=n; i++) {
last=1;
cin>>s[i];
ln=s[i].length();
for(int j=0; j<ln; j++)
last=Ins(s[i][j]-'a',last);
}
for(int i=1; i<=n; i++) {
ln=s[i].length();maxlen=max(maxlen,ln);
int pos=1;
for(int j=0; j<ln; j++) {
pos=ch[pos][s[i][j]-'a'];
int p=pos;
for(; p>=2&&col[p]!=i; p=fa[p]) {
vis[p]++;
col[p]=i;
}
}
}
for (int i=1; i<=tot; i++) cnt[len[i]]++;
for (int i=1; i<=maxlen; i++) cnt[i]+=cnt[i-1];
for (int i=1; i<=tot; i++) id[cnt[len[i]]--]=i;
for (int i=2; i<=tot; i++) {
if(vis[id[i]]>=m) {
sum[id[i]]+=len[id[i]]-len[fa[id[i]]];
}
sum[id[i]]+=sum[fa[id[i]]];
}
for(int i=1; i<=n; i++) {
ln=s[i].length();
int length=0;
int pos=1;
int res=0;
for(int j=0; j<ln; j++) {
length++;
pos=ch[pos][s[i][j]-'a'];
int p=pos;
if(vis[p]>=m) {
res+=length-len[fa[p]];
}
res+=sum[fa[p]];
}
cout<<res<<" ";
}
cout<<"\n";
}