simple 点分治



​CF914E Palindromes in a Tree 题目传送门​​。

题意简述:给一棵树,节点上有字符。对于每个点,求有多少条经过该点的路径满足该路径上出现奇数次的字符最多有 \(1\) 个。

\(n\leq 2\times 10^5\),字符集大小为 \(20\)。 

话说这个标题有点误导人,一开始想严格意义上的回文串觉得很不可做。



树上路径统计,不难想到点分治,而判断最多有一个字符出现奇数次可以用状压。具体地,对于一个分治中心 \(r\),用一个 \(20\) 位二进制数统计它的所有子节点 \(u\) 到 \(r\)(不包括 \(r\))的路径上每个字符的出现次数的奇偶性,记为 \(msk_u\)。

同时,如果我们知道了有多少条路径以端点 \(u\) 结尾并经过分治中心,设为 \(p_u\),那么从 \(u\) 到 \(r\) 上的所有点(除了 \(r\) 需要单独处理)都要加上 \(p_u\)。由于是离线所以使用树上差分即可。

对于每个 \(u\),我们将只到分治中心的路径和跨过分治中心的路径分开来考虑,因为在根节点 \(r\) 处,前者只会被计算一次,而后者会被计算两次。记 \(c=msk_u\oplus 2^{s_r}\)。

  • 只到分治中心:如果存在 \(k\) 使得 \(c\oplus 2^k=0\) 或 \(c\) 本身就为 \(0\),那么路径 \((u,r)\) 是合法的。将 \(p_u\) 和 \(ans_r\) 增加 \(1\)。
  • 跨过分治中心:枚举出现次数为奇数的字符 \(k\),对于所有使得 \(c\oplus 2^k\oplus msk_v=0\)(即 \(c\oplus 2^k=msk_v\))的 \(v\),\((u,v)\) 都是一条合法路径,可以通过用两个桶分别记录在整棵 \(r\) 的子树内 和 在当前点 \(u\) 所在的子树中(为了除去不经过分治中心的答案)\(msk_v=x\) 的 \(v\) 数量,分别记为 \(buc1_x\) 和 \(buc2_x\),那么 \(v\) 的数量即为 \(buc1_{c\oplus k}-buc2_{c\oplus k}\)。不要忘记考虑没有出现次数为奇数的情况。

综上,我们有了一个 \(\mathcal{O}(n\log n\mathbf|Σ|)\) 的算法:点分治,统计答案时枚举 \(i=0\) 和 \(i=2^k\),记 \(c=msk_u\oplus 2^{s_r}\oplus i\)。若 \(c=0\),则将标记记为 \(1\)。同时将 \(p_u\) 加上 \(buc1_c-buc2_c\)。枚举结束后,若标记为 \(1\),则将 \(p_u\) 和 \(ans_r\) 都加上 \(1\)。最后使用树上差分统计非分治中心的所有节点对答案的贡献。

记得将桶清零。此外,不需要在一开始将 \(buc1_0\) 设为 \(1\)。

#include<bits/stdc++.h>
using namespace std;

#define db double
#define ll long long
#define ld long double
#define ull unsigned long long

#define pii pair <int,int>
#define fi first
#define se second
#define pb emplace_back
#define mem(x,v,s) memset(x,v,sizeof(x[0])*(s))
#define cpy(x,y,s) memcpy(x,y,sizeof(x[0])*(s))

const int N=2e5+5;
const int S=20;

int n,vis[N];
char s[N];
vector <int> e[N];

ll ans[N],buc[1<<S],insi[1<<S],p[N];
vector <pii> res[N];
int root,mx[N],sz[N];

void findr(int id,int f,int tot){
mx[id]=sz[id]=1;
for(int it:e[id])
if(it!=f&&!vis[it]){
findr(it,id,tot),sz[id]+=sz[it];
mx[id]=max(mx[id],sz[it]);
}
mx[id]=max(mx[id],tot-sz[id]);
if(mx[id]<mx[root])root=id;
}
void presum(int id,int f){
for(int it:e[id])
if(it!=f&&!vis[it])
presum(it,id),
p[id]+=p[it];
if(f)ans[id]+=p[id];
}
void calc(int id,int f,int pre,int anc){
res[anc].pb((pii){pre^=(1<<s[id]-'a'),id}),p[id]=0;
for(int it:e[id])if(it!=f&&!vis[it])calc(it,id,pre,anc);
}
void divide(int id){
int msk=1<<s[id]-'a';
vis[id]=1,p[id]=0;
for(int it:e[id])
if(!vis[it]){
res[it].clear(),calc(it,id,0,it);
for(pii c:res[it])buc[c.fi]++;
}
ll sum=0;
for(int it:e[id]){
if(vis[it])continue;
for(pii c:res[it])insi[c.fi]++;
for(pii c:res[it]){
bool have=0;
for(int k=-1;k<20;k++){
int msk2=(k<0?0:1<<k),r=c.fi^msk^msk2;
if(!r)have=1;
ll d=buc[r]-insi[r];
sum+=d,p[c.se]+=d;
} if(have)p[c.se]++,ans[id]++;
} for(pii c:res[it])insi[c.fi]=0;
} presum(id,0),ans[id]+=sum>>1,assert(sum%2==0);
for(int it:e[id])for(pii c:res[it])buc[c.fi]=0;
for(int it:e[id])if(!vis[it])root=0,findr(it,id,sz[it]),divide(root);
}


int main(){
cin>>n,mx[0]=n+1;
for(int i=1;i<n;i++){
int u,v; cin>>u>>v;
e[u].pb(v),e[v].pb(u);
} scanf("%s",s+1);
findr(1,0,n),divide(root);
for(int i=1;i<=n;i++)cout<<ans[i]+1<<" ";
cout<<endl;
return 0;
}