抄题解.jpg
发现原来的\(O(n^2)\)的换根\(dp\)好像行不通了呀
我们考虑非常牛逼的长链剖分
我们设\(f[x][j]\)表示在\(x\)的子树中距离\(x\)为\(j\)的点有多少个
\(g[x][j]\)表示在\(x\)的子树里,满足如下条件的点对\((u,v)\)的个数
-
设\(k=LCA(u,v)\),满足\(dis(u,k)=dis(v,k)=d\)
-
满足\(dis(k,x)=d-j\)
我们发现可以如果\(v\)是\(x\)的儿子,那么距离\(v\)为\(j-1\)的点和\(x\)的距离就是\(j\),那么到\(k\)的距离就是\(d-j+j=d\),和点对到\(k\)的距离相等
于是我们可以这样合并
自然还有
\(f\)数组的更新非常简单啊,就是\(f[x][j]+=f[v][j-1]\),这个我们可以用长链剖分优化到\(O(n)\)
之后是\(g\)的更新
首先我们有\(g[x][j]+=g[v][j+1]\),就是到\(x\)距离为\(d-j\)的\(k\)到\(v\)的距离必然是\(d-j-1\),这里我们也可以直接长链剖分
之后\(g[x][j+1]+=f[x][j+1]\times f[v][j]\),这样产生的点对的\(LCA\)就是\(x\),到\(x\)的距离也就是\(j+1\),符合条件,这里直接暴力转移就好了
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
const int maxn=100006;
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
struct E{int v,nxt;}e[maxn<<1];
int head[maxn],len[maxn],n,num,son[maxn],deep[maxn];
LL tax[maxn*6],*id=tax,*f[maxn],*g[maxn],ans;
inline void add(int x,int y) {
e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs1(int x) {
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]) continue;
deep[e[i].v]=deep[x]+1;
dfs1(e[i].v);
if(len[e[i].v]>len[son[x]]) son[x]=e[i].v;
}
len[x]=len[son[x]]+1;
}
void dfs(int x) {
f[x][0]=1;
if(son[x]) {
g[son[x]]=g[x]-1;
f[son[x]]=f[x]+1;
dfs(son[x]);
}
ans+=g[x][0];
for(re int i=head[x];i;i=e[i].nxt) {
if(deep[e[i].v]<deep[x]||son[x]==e[i].v) continue;
f[e[i].v]=id;id+=len[e[i].v]+1;
g[e[i].v]=id+len[e[i].v]+1;id+=2*len[e[i].v]+2;
dfs(e[i].v);
for(re int j=len[e[i].v];j>=0;--j) {
if(j) ans+=f[x][j-1]*g[e[i].v][j];
ans+=g[x][j+1]*f[e[i].v][j];
g[x][j+1]+=f[e[i].v][j]*f[x][j+1];
}
for(re int j=0;j<=len[e[i].v];j++) {
if(j) g[x][j-1]+=g[e[i].v][j];
f[x][j+1]+=f[e[i].v][j];
}
}
}
int main() {
n=read();
for(re int x,y,i=1;i<n;i++)
x=read(),y=read(),add(x,y),add(y,x);
deep[1]=1;dfs1(1);
f[1]=id;id+=len[1]+1;
g[1]=id+len[1]+1;//由于我们继承重儿子是g[son[x]]=g[x]-1,所以得在这个指针前面留一些空位置来让后面的状态继承
id+=2*len[1]+2;
dfs(1);printf("%lld\n",ans);
return 0;
}