题目
考场上送\(75pts\)真实良心,正解不难;考虑直接对于每一个点算割掉多少条边能使得这个点成为重心,不难发现对于一个不是重心的点,我们要割掉的那条边一定在那个大于\(\lfloor \frac{n}{2} \rfloor\)的子树里面,而最大子树割掉之后可能就不是最大的了,但新的最大子树只可能是原来的次大子树,推一下柿子要割掉的子树大小需要在\([2A-n,n-2B]\)之间,其中\(A\)为最大子树,\(B\)为次大子树
于是我们先求一个重心作为根,这样所有非重心节点的最大子树就会跨过这个根,在dfs的过程就能更新子树的大小,用树状数组维护一下就好了,由于需要排除子树内部的情况,所以还需要一个线段树合并;至于重心节点不超过两个,暴力求一下就好
代码
inline int read() {
char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=3e5+5;
const int M=maxn*25;
int l[M],r[M],d[M];
struct E{int v,nxt;}e[maxn<<1];
int T,n,num,__,cnt;LL ans=0;
int head[maxn],sum[maxn],rt[maxn],mx[maxn],c[maxn],t[maxn],sz[maxn],col[maxn];
inline void add(int x,int v) {
for(re int i=x;i<=n;i+=i&(-i)) c[i]+=v;
}
inline int ask(int x) {
int nw=0;
for(re int i=x;i;i-=i&(-i)) nw+=c[i];
return nw;
}
inline void add_E(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void Dfs(int x,int fa) {
sum[x]=1;mx[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(e[i].v==fa) continue;
Dfs(e[i].v,x);sum[x]+=sum[e[i].v];mx[x]=max(mx[x],sum[e[i].v]);
}
mx[x]=max(n-sum[x],mx[x]);
}
int ins(int nw,int x,int y,int pos) {
if(!nw) nw=++cnt,d[nw]=l[nw]=r[nw]=0;d[nw]++;
if(x==y) return nw;
int mid=x+y>>1;
if(pos<=mid) l[nw]=ins(l[nw],x,mid,pos);
else r[nw]=ins(r[nw],mid+1,y,pos);
return nw;
}
int merge(int a,int b,int x,int y) {
if(!a||!b) return a|b;
if(x==y) {
d[a]+=d[b];
return a;
}
int mid=x+y>>1;
l[a]=merge(l[a],l[b],x,mid);r[a]=merge(r[a],r[b],mid+1,y);
d[a]=d[l[a]]+d[r[a]];return a;
}
int query(int nw,int x,int y,int lx,int ry) {
if(!nw||lx>ry) return 0;
if(lx<=x&&ry>=y) return d[nw];
int mid=x+y>>1,h=0;
if(lx<=mid) h+=query(l[nw],x,mid,lx,ry);
if(ry>mid) h+=query(r[nw],mid+1,y,lx,ry);
return h;
}
void dfs(int x,int fa) {
rt[x]=ins(rt[x],1,n,sz[x]);
if(fa) add(sz[fa],-1),add(n-sz[x],1);
for(re int i=head[x];i;i=e[i].nxt)
if(e[i].v!=fa) dfs(e[i].v,x),rt[x]=merge(rt[x],rt[e[i].v],1,n);
if(mx[x]+mx[x]>n) {
int k=0;
if(mx[x]-t[x]>=2*mx[x]-n) k=ask(mx[x]-t[x])-ask(2*mx[x]-n-1);
if(mx[x]-t[x]<n-2*t[x]) k+=ask(n-2*t[x])-ask(mx[x]-t[x]);
k-=query(rt[x],1,n,2*mx[x]-n,mx[x]-t[x]);
k-=query(rt[x],1,n,mx[x]-t[x]+1,n-2*t[x]);
ans+=1ll*k*x;
}
if(fa) add(sz[fa],1),add(n-sz[x],-1);
}
void Dfs_(int x,int fa) {
sz[x]=1,t[x]=0;
for(re int i=head[x];i;i=e[i].nxt) {
if(e[i].v==fa) continue;
Dfs_(e[i].v,x);sz[x]+=sz[e[i].v];t[x]=max(t[x],sz[e[i].v]);
}
}
void DFs(int x,int fa,int cm) {
col[x]=cm;sz[x]=1;
for(re int i=head[x];i;i=e[i].nxt) {
if(e[i].v==fa) continue;
DFs(e[i].v,x,cm);sz[x]+=sz[e[i].v];
}
}
void solve(int Rt) {
int col_num=1,A=0,B=0;
for(re int i=head[Rt];i;i=e[i].nxt,col_num++) {
DFs(e[i].v,Rt,col_num);
if(sz[e[i].v]>=sz[A]) B=A,A=e[i].v;
else if(sz[e[i].v]>sz[B]) B=e[i].v;
}
for(re int i=1;i<=n;i++) {
if(i==Rt) continue;
if(col[i]!=col[A]&&2*sz[A]<=(n-sz[i])) ans+=Rt;
if(col[i]==col[A]&&2*max(sz[A]-sz[i],sz[B])<=(n-sz[i])) ans+=Rt;
}
}
int main() {
T=read();
for(re int Rt;T;--T) {
cnt=0;ans=0;n=read(),num=0,__=0;memset(head,0,sizeof(head));memset(rt,0,sizeof(rt));memset(c,0,sizeof(c));
for(re int x,y,i=1;i<n;i++) x=read(),y=read(),add_E(x,y),add_E(y,x);
Dfs(1,0);for(re int i=1;i<=n;i++) if(mx[i]+mx[i]<=n) Rt=i;
Dfs_(Rt,0);
for(re int i=1;i<=n;i++) add(sz[i],1);dfs(Rt,0);
for(re int i=1;i<=n;i++) if(mx[i]+mx[i]<=n) solve(i);
printf("%lld\n",ans);
}
return 0;
}