题目

现在有一棵树,小AA想会选一个点出发,沿着树上的边不回头的行走。

小AA不喜欢走太久,所以他最多只会行走kk条边。

小AA也不喜欢重复,所以他想要知道最多有多少种行走方式。

两个方式不同当且仅当经过的边集不同。

输入格式
第一行两个整数,n,kn,k,nn表示树的节点个数。

第二行一共n−1n−1个整数,第ii个数fi+1fi+1表示i+1i+1的父亲。

输出格式
一行,一个整数,表示行走方案数。

样例
Input
5 2
1 1 2 2
Output
8
数据规模与约定
对于10%10%的数据,满足n,m≤100n,m≤100;

对于30%30%的数据,满足n,m≤1000n,m≤1000;

对于50%50%的数据,满足n,m≤20000n,m≤20000;

对于80%80%的数据,满足n,m≤105n,m≤105;

对于100%100%的数据,满足n,m≤106,fi<in,m≤106,fi<i;

时间限制:1s

空间限制:64MB

思路

看起来像点分治板子题,但是空间只有64M(虽然空间理论是O(n),但是运存不知道为什么大得离谱),并且点分治时间常数较大

考虑长链剖分+dsu on tree
设f[i][j]为以i为根的子树,距离点i为j的点数个数,g[i][j]为f[i][j]的后缀和
转移时用f[v][j]*(g[u][0]-g[u][j-k])更新ans,f[v][j]更新f[v][j+1]
发现f可以省掉

代码

#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define pb push_back
#define mp make_pair
using namespace std;
const int N=1e6+77;
int n,k,f[N],ls[N],dep[N],nxt[N],dfn[N],son[N],cnt;
int q,X,aa[N],sum[N],K[N],B[N],Q[N*30],c[N],bot,a[N];
ll ans;
struct Node
{
	int l,r,st,en;
}T[N<<2];
bool cmp(int x,int y)
{
	return K[x]>K[y]||(K[x]==K[y]&&B[x]<B[y]);
}
double pos(int x,int y) 
{
	return (double)(B[y]-B[x]) / (K[x]-K[y]); 
}
void baoli()
{
	scanf("%d%d",&n,&q);
	for(int i=1; i<=n; i++) scanf("%lld",&a[i]),sum[i]=sum[i-1]+a[i];
	while(q--)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		ll aii=0x3f3f3f3f,yjy=0x3f3f3f3f;
		for(int i=y; i>=max(y-x,1); i--)
		{
			aii=min(aii,1ll*a[i]); 
			yjy=min(yjy,aii*(x-y+i-1)+sum[y]-sum[i-1]);
		}
		printf("%lld\n",yjy);
	}
}
void build(int i,int l,int r)
{
	T[i].l=l;
	T[i].r=r;
	int t=0,now=T[i].st=bot+1;
	for(int j=l; j<=r; j++)
		c[++t]=j;
	sort(c+1,c+t+1,cmp);
	Q[++bot]=c[1];
	for(int j=2; j<=t; j++)
		if(K[c[j]]<K[c[j-1]])
		{
			while(bot>now&&pos(Q[bot-1],Q[bot])>pos(Q[bot],c[j]))
				bot--;
			Q[++bot]=c[j];
		}
	T[i].en=bot;
	if(l==r)
		return;
	int M=l+r >> 1;
	build(i<<1,l,M);
	build(i<<1|1,M+1,r);
}

void query(int l,int r)
{
	while(l<r)
	{
		int M=(l+r >> 1)+1;
		if(X >= pos(Q[M-1],Q[M]))
			l=M;
		else
			r=M-1; 
	}
	ans=min(ans,K[Q[l]]*X+B[Q[l]]*1ll);
}

void query(int i,int l,int r)
{
	if(l<=T[i].l&&T[i].r<=r)
	{
		query(T[i].st,T[i].en);
		return;
	}
	int M=T[i].l+T[i].r >> 1;
	if(l<=M)
		query(i<<1,l,r);
	if(r>M)
		query(i<<1|1,l,r);
}

void dfs(int u)
{
	for(int v=ls[u];v;v=nxt[v])
	{
		dfs(v);
		if(dep[v]>dep[u]) dep[u]=dep[v],son[u]=v;
	}
	dep[u]++;
}
int calc(int x,int l,int r)
{
	return f[dfn[x]+l]-((r+1<dep[x])?f[dfn[x]+r+1]:0);
}
void dfs2(int u)
{
	dfn[u]=++dfn[0];
	if(son[u])
	{
		dfs2(son[u]);
		ans += calc(son[u],0,k-1);
		f[dfn[u]]=f[dfn[u]+1]+1;
	}else f[dfn[u]]=1;
	for(int v=ls[u];v;v=nxt[v])
		if(v != son[u])
			dfs2(v);
	for(int v=ls[u];v;v=nxt[v])
	{
		if(v != son[u])
		{
			for(int i=0; i<=min(dep[v],k)-1; i++) ans+=1ll*calc(v,i,i)*calc(u,0,k-i-1);
			for(int i=0; i<=dep[v]-1; i++) f[dfn[u]+i+1] += f[dfn[v]+i];
			f[dfn[u]] += f[dfn[v]];
		}
	}
}
int main()
{
 	scanf("%d%d",&n,&k);
 	for(int i=2; i<=n; i++)
	{
 		int x;
 		scanf("%d",&x);
 		nxt[i]=ls[x];
 		ls[x]=i;
 	}
 	dfs(1);
 	dfs2(1);
 	printf("%lld",ans);
}