D. MEX Tree(容斥+双指针)

a n s i ans_i ansi为异或值为 i i i的路径数。

f i f_i fi为异或值 ≥ i \ge i i的路径数。

g i g_i gi为异或值 > i >i >i​的路径数。

s z i sz_i szi表示结点 i i i的子树大小。

容易发现: g i = f i + 1 g_i=f_{i+1} gi=fi+1

考虑用容斥计算 a n s i = f i − g i ans_i=f_i-g_i ansi=figi

如果我们能计算出 g i , i ∈ [ 1 , n ] g_i,i\in [1,n] gi,i[1,n] 我们就可以递推得到 a n s i ans_i ansi​。

f 0 f_0 f0 就是整棵树的路径数, a n s 0 ans_0 ans0 所有 0 0 0的儿子子树内的路径数之和。

f 0 = n ( n − 1 ) 2 , a n s 0 = ∑ v ∈ s o n 0 s z v ( s z v − 1 ) 2 \large f_0=\dfrac{n(n-1)}{2},ans_0=\sum\limits_{v\in son_0}\dfrac{sz_v(sz_v-1)}{2} f0=2n(n1),ans0=vson02szv(szv1)

g 0 = f 0 − a n s 0 = f 1 g_0=f_0-ans_0=f_1 g0=f0ans0=f1

所以接下来的关键就是求 g i , ∈ [ 1 , n ] g_i, \in [1,n] gi,[1,n]

考虑用双指针。

用两个指针 l , r l,r l,r​ 分别表示 g i g_i gi​ 对应的链的端点。

则答案就是: s z l × s z r sz_l\times sz_r szl×szr

显然这条链包含 0 , 1 , 2 , … , i 0,1,2,\dots ,i 0,1,2,,i 结点,且 i = l i=l i=l i = r i=r i=r

我们考虑这条链如何扩展到 g i g_{i} gi

显然要满足 g i g_{i} gi​​​必须满足 g i − 1 g_{i-1} gi1​​​,所以 g i g_i gi​​​ 必须在 p a t h ( l , r ) path(l,r) path(l,r)​​​这条路径上扩展,也就是说 i i i​​​ 往上走必须走到 l l l​​​或者 r r r​​​。

如果 i i i​走步到 l l l​或者 r r r​,说明无解 g i = 0 g_{i}=0 gi=0,则 a n s j = 0 , ( j   > i ) ans_j=0,(j\ >i) ansj=0,(j >i)直接 b r e a k break break

否则更新 p a t h ( l , r ) path(l,r) path(l,r)

需要注意的是如果 l = 0 l=0 l=0​或者 r = 0 r=0 r=0​,计算贡献时要 s z 0 sz_0 sz0​减掉对应的那颗子树,注意是永久性修改,因为链是继承关系。

时间复杂度: O ( n ) O(n) O(n)

#include<bits/stdc++.h>
#define ll long long
const int N = 2e5+10;
int t,n,fa[N],cov[N];
ll sz[N],ans[N],fl,p,q;
std::vector<int> g[N];
void dfs(int x,int f){
	fa[x] = f ;
	sz[x] = 1 ; 
	for(auto p:g[x]){
		if(p==f) continue;
		dfs(p,x);
		sz[x]+=sz[p];
	}
}
void add(int x){
	if(cov[x]) return ;
	if(!cov[fa[x]]) add(fa[x]); 
	cov[x] = 1;
	if(fa[x]!=q&&fa[x]!=p){
		fl = 0 ;
		return ;
	}
	else{
		 if(fa[x]==p) p=x;
		 else q=x;
		 if(fa[x]==0) sz[0]-=sz[x];
	}
}
int main(){
	scanf("%d",&t);
	while(t--){
		scanf("%d",&n);
		for(int i=0;i<=n;++i) g[i].clear(),ans[i]=0,cov[i]=0;
		for(int i=1;i<=n-1;++i){
			int u,v;
			scanf("%d%d",&u,&v);
			g[u].push_back(v);
			g[v].push_back(u);
		}
		dfs(0,-1);//fa sz
		ll sum = 1 ;
		for(auto p:g[0]){
			ans[0]+=sz[p]*(sz[p]-1)/2;//ans_0
			ans[1]+=sz[p]*sum;// mex>=1
			sum+=sz[p];
		}
		fl = 1 , p = q = 0; cov[0] = 1;
		for(int i=1;i<n;++i){
			add(i);
			if(!fl) break;
			ans[i] -= 1ll*sz[p]*sz[q];  // -   mex >i
			ans[i+1] = 1ll*sz[p]*sz[q];		} 
		for(int i=0;i<=n;++i)
			printf("%lld ",ans[i]);
		puts("");
	}
	return 0;
}

·