2021牛客暑期多校第四场D.Rebuild Tree

题意

给一棵\(n\)个点的树,求删去其中\(k\)条边再加入\(k\)条边后仍然是一颗树的方案数。

\(2\le n\le 5\times10^4,1\le k\le \min(n-1,100)\)

题解

prufer序列

prufer序列是一种将一颗\(n\)个节点的有标号树用唯一的一个整数序列表示的方法。

一棵树的prufer序列构造过程如下:每次选择一个编号最小的叶结点并删掉它,然后在序列中记录下它连接到的那个结点编号。重复\(n-2\)次后就只剩下两个结点,算法结束。所以\(n\)个点的有标号树共有\(n^{n-2}\)种。

树上每个点的出现次数等于其对应的prufer序列中该数字出现次数加一

首先考虑删去的\(k\)条边已经确定时的答案,即为在形成的\(k+1\)个连通块之间连边形成树的方案数。不妨再假设这\(k+1\)个联通块之间的连边情况也已经确定。记连通块\(i\)的度数为\(s_i\),在对应prufer序列中的出现次数为\(P_i\),则此时答案为\(\prod_{i=1}^{k+1}s_i^{d_i}=\prod_{i=1}^{k+1}s_i^{P_i+1}=\prod_{i=1}^{k+1}s_i\prod_{i=1}^{k+1}s_i^{P_i}\),其中\(\prod_{i=1}^{k+1}s_i\)与连通块之间的连边情况无关。那么,\(k+1\)个连通块之间连边形成树的方案数为

 

\[\sum_{所有长度为k-1的prufer序列}\prod_{i=1}^{k+1}s_i\prod_{i=1}^{k+1}s_i^{P_i}=\prod_{i=1}^{k+1}s_i(\sum_{\sum_{i=1}^{k+1}P_i=k-1}\prod_{i=1}^{k+1}s_i^{P_i}) \]

 

考虑后一个式子的组合意义,相当于一个长度为\(n\)的序列,每个位置可以填入\(1...n\)的任何数,所以

 

\[\sum_{\sum_{i=1}^{k+1}P_i=k-1}\prod_{i=1}^{k+1}s_i^{P_i}=n^{k-1} \]

 

OI-wiki上关于这个结论的证明
所以这道题的需要计算的答案为

 

\[\sum_{将原树分成k+1个连通块}n^{k-1}\prod_{i=1}^{k+1}s_i=n^{k-1}(\sum_{将原树分成k+1个连通块}\prod_{i=1}^{k+1}s_i) \]

 

考虑后一个式子的组合意义,相当于将原树去掉\(k\)条边然后在每个连通块中选一个点的方案数,而这可以用dp计算,具体细节见代码。

#include <bits/stdc++.h>
#define pb(x) emplace_back(x) 
using namespace std;
using ll=long long ;
const int N=50005;
const ll M=998244353;
int n,k,sz[N],b[N];
ll f[N][102][2],g[105][2];
void MOD(ll&x){x%=M;}
ll pm(ll x,ll b){x%=M;ll res=1;while(b){if(b&1)res=res*x%M;x=x*x%M;b>>=1;}return res;}
vector<int> e[N];
//f[u][i][0/1] 表示u的子树内删了i条边,u所在的联通块没选/选了点的方案数 
void dfs(int u,int fa){
	sz[u]=1;
	f[u][0][1]=f[u][0][0]=1;
	for(auto v:e[u])if(v!=fa){
		dfs(v,u); 
		memset(g,0,sizeof(g));
		for(int i=0;i<sz[u]&&i<=k;i++){//此时最多sz[u]-1条边 
			for(int j=0;j<sz[v]&&i+j<=k;j++){//最多sz[v]-1条边
				//不删(u,v)这条边 
				g[i+j][0]+=f[u][i][0]*f[v][j][0]%M;
				MOD(g[i+j][0]);
				g[i+j][1]+=(f[u][i][1]*f[v][j][0]%M+f[u][i][0]*f[v][j][1]%M)%M;
				MOD(g[i+j][1]);
				if(i+j<k){
					//删掉(u,v)这条边,此情况下v所在的联通块必须已经选点 
					g[i+j+1][0]+=f[u][i][0]*f[v][j][1]%M;
					MOD(g[i+j+1][0]);
					g[i+j+1][1]+=f[u][i][1]*f[v][j][1]%M;
					MOD(g[i+j+1][1]);
				} 
			}
		}
		memcpy(f[u],g,sizeof(g));
		sz[u]+=sz[v];
	} 
}
void f1(){
	cin>>n>>k;
	for(int i=1;i<n;i++){
		int x,y;cin>>x>>y;
		e[x].pb(y);e[y].pb(x); 
	}
	dfs(1,0);
	ll ans=pm(n,k-1)*f[1][k][1]%M;
	cout<<ans;
}
int main(){
	f1();
	return 0;
}