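和之前那道【LNOI】LCA 几乎是同一道题，只是要用动态树（LCT）来维护差分：把每个点 u 的权值设为 dep(u)^K − (dep(u)−1)^K，按编号 i 从小到大把根到 i 路径上每个点的覆盖次数加一；对询问 (x, y)，在处理完 i = x 之后查询根到 y 路径上“权值 × 覆盖次数”之和即可。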
code:
// Offline answers to queries (x, y): sum over i = 1..x of dep(lca(i, y))^K
// modulo 998244353, on a tree rooted at node 1 (root depth = 1).
//
// Each node u gets weight w(u) = dep(u)^K - (dep(u) - 1)^K, so the weights on
// the root->v path telescope to dep(v)^K. Nodes i = 1..n are processed in
// increasing order; processing i adds +1 to a coverage counter on every node
// of the root->i path. A query (x, y), answered right after i = x, is the sum
// of w(v) * counter(v) over the root->y path. Both path operations use a
// link-cut tree on the fixed root, so no make-root / reversal is needed.
#include <cstdio>
#include <vector>

constexpr int N = 50006;           // node capacity; index 0 is the null sentinel
constexpr int MOD = 998244353;
using ll = long long;

// Debug helper: read stdin from "<s>.in".
#define setIO(s) freopen(s ".in", "r", stdin)

int n, Q, K;
int hd[N], to[N], nex[N], edges;   // parent -> child adjacency lists
int dep[N];                        // node depth, root = 1
int sta[N];                        // scratch stack used by splay()

struct Node {
    int f;          // splay parent, or path-parent when this node is a splay root
    int ch[2];      // splay children (0 = none)
    int add;        // lazy counter increment not yet pushed to the children
    ll val1, val2;  // val1 = w(x); val2 = w(x) * counter(x) (mod MOD)
    ll sum1, sum2;  // sums of val1 / val2 over this node's splay subtree
} t[N];

struct Query {
    int y, id;
};
std::vector<Query> a[N];           // a[x]: queries whose prefix bound is x

// Appends child edge u -> v.
inline void add_edge(int u, int v) {
    nex[++edges] = hd[u];
    hd[u] = edges;
    to[edges] = v;
}

// x^y mod MOD by binary exponentiation (x >= 0, y >= 0).
inline int qpow(int x, int y) {
    ll base = x % MOD, res = 1;
    for (; y; y >>= 1, base = base * base % MOD)
        if (y & 1) res = res * base % MOD;
    return static_cast<int>(res);
}

// 1 if x is the right child of its splay parent, 0 otherwise.
inline int which_child(int x) { return t[t[x].f].ch[1] == x; }

// True if x is the root of its splay tree (its f is a path-parent or 0).
inline bool is_splay_root(int x) {
    const int p = t[x].f;
    return t[p].ch[0] != x && t[p].ch[1] != x;
}

// Recomputes x's subtree sums from its children (t[0] sums are always 0).
inline void pushup(int x) {
    const int l = t[x].ch[0], r = t[x].ch[1];
    t[x].sum1 = (t[l].sum1 + t[r].sum1 + t[x].val1) % MOD;
    t[x].sum2 = (t[l].sum2 + t[r].sum2 + t[x].val2) % MOD;
}

// Rotates x above its splay parent, keeping the path-parent link intact.
inline void rotate(int x) {
    const int old = t[x].f, fold = t[old].f, which = which_child(x);
    // Must be checked before old's parent pointer changes below.
    if (!is_splay_root(old)) t[fold].ch[t[fold].ch[1] == old] = x;
    const int b = t[x].ch[which ^ 1];
    t[old].ch[which] = b;
    if (b) t[b].f = old;  // never write into the t[0] sentinel
    t[x].ch[which ^ 1] = old;
    t[old].f = x;
    t[x].f = fold;
    pushup(old);
    pushup(x);
}

// Adds d to the coverage counter of every node in x's splay subtree (lazily).
inline void mark(int x, int d) {
    t[x].val2 = (t[x].val2 + 1ll * d * t[x].val1) % MOD;
    t[x].sum2 = (t[x].sum2 + 1ll * d * t[x].sum1) % MOD;
    t[x].add += d;
}

// Pushes x's pending counter increment down to its children.
inline void pushdown(int x) {
    if (x && t[x].add) {
        for (int c : t[x].ch)
            if (c) mark(c, t[x].add);
        t[x].add = 0;
    }
}

// Splays x to the root of its splay tree.
void splay(int x) {
    // Push lazy tags down along the splay path, top-most first.
    int top = 0, u = x;
    sta[++top] = u;
    while (!is_splay_root(u)) {
        u = t[u].f;
        sta[++top] = u;
    }
    while (top) pushdown(sta[top--]);

    const int outer = t[u].f;  // path-parent of this splay tree; stays fixed
    for (int fa; (fa = t[x].f) != outer; rotate(x)) {
        if (t[fa].f != outer) rotate(which_child(fa) == which_child(x) ? fa : x);
    }
}

// Makes the root->x path preferred (one splay tree ending at x).
void Access(int x) {
    for (int y = 0; x; y = x, x = t[x].f) {
        splay(x);
        t[x].ch[1] = y;
        pushup(x);
    }
}

// Computes dep[] and the node weights top-down in BFS order. Iterative, so a
// deep chain-shaped tree cannot overflow the call stack.
void build() {
    std::vector<int> order;
    order.push_back(1);
    dep[1] = 1;
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        // The root's (dep - 1)^K = 0^K term is taken as 0, so the telescoping
        // sum equals dep^K even when K == 0 (qpow(0, 0) would return 1).
        ll w = qpow(dep[u], K);
        if (dep[u] > 1) w = (w - qpow(dep[u] - 1, K) + MOD) % MOD;
        t[u].val1 = w;
        t[u].sum1 = w;  // no splay children yet
        for (int e = hd[u]; e; e = nex[e]) {
            dep[to[e]] = dep[u] + 1;
            order.push_back(to[e]);
        }
    }
}

int main() {
    // setIO("input");
    if (scanf("%d%d%d", &n, &Q, &K) != 3) return 0;
    for (int i = 2; i <= n; ++i) {
        // The tree parent doubles as the initial LCT path-parent pointer.
        scanf("%d", &t[i].f);
        add_edge(t[i].f, i);
    }
    build();

    std::vector<int> answer(Q > 0 ? Q + 1 : 1, 0);
    for (int i = 1; i <= Q; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        // Out-of-range queries are never processed and keep answer 0.
        if (1 <= x && x <= n && 1 <= y && y <= n) a[x].push_back({y, i});
    }

    for (int i = 1; i <= n; ++i) {
        Access(i);
        splay(i);
        mark(i, 1);  // +1 coverage on every node of root->i
        for (const Query& q : a[i]) {
            Access(q.y);
            splay(q.y);
            answer[q.id] = static_cast<int>(t[q.y].sum2 % MOD);
        }
    }
    for (int i = 1; i <= Q; ++i) printf("%d\n", answer[i]);
    return 0;
}