\(E.Wandering TKHS\\\)
题目链接:AGC029E
题意简述:给你一个\(n\)个点的树,一开始你有一个集合\(S\),包含一个元素\(r\),然后你每次会选择所有与集合\(S\)中的元素在树上相邻的点中最小的,然后加入\(S\),直到\(1 \in S\),定义此时的\(|S| = c(r)\),求\(\forall r \in[2,n],c(r)\)
\(n \leq 2e5\)
稍微好想一点的神仙题,本篇题解实际上约等于对官方题解的翻译
考虑模拟题目这一过程,以\(1\)为根,注意到\(r\)要到达\(1\),一定要经过它祖先的所有点,其中比较特殊的点是最大的那个,定义为\(M_r\)
对于这种题,我们考虑增量法,思考\(c(fa_r)与c(r)\)的关系
定义\(gfa(r) = fa(fa(r)),Q(x,v)\)为以\(x\)为根的子树中只经过\(\leq v\)的点能扩展的大小
分情况讨论,若\(M_{gfa(r)} < r\),那么意味着若集合初始点是\(fa(r)\),它永远不会经过\(r\),那么有\(c(r) = c(fa_r) + Q(r,M(fa_r))\)
否则的话,\(fa_r\)扩展的过程会经过\(r\),一般来说\(c(fa_r) = c(r)\),除非有一种特殊情况,即\(M(fa_r) = fa_r\),那么此时\(c(r) = c(fa_r) - Q(r,M(gfa_r))+ Q(r,M(fa_r))\),注意到即使\(M(fa_r) != fa_r\),上式依然成立,因为此时\(M(gfa_r) = M(fa_r)\)
考虑如何计算\(Q(r,M(fa_r))和Q(r,M(gfa_r))\),首先显然可以线段树合并,但过于麻烦,考虑以下这种记忆化搜索
首先注意看转移
int ans = 1; for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; if(v ^ par[u] && v <= val) ans += solveQ(v,val); }
- 注意到若\(fa_r = M(fa_r)\),则\(gra_r\)在计算\(Q\)值时便不会计算到\(fa_r\)
- 否则的话\(M(fa_r) = M(gfa_r)\),并不会多计算
- 所以需要存的\(Q\)的状态很少,记忆化搜索即可
- 代码如下:
#includeusing namespace std; #define ll long long const int _ = 2e5 + 7; int read(){ char c = getchar(); int x = 0; while(c < '0' || c > '9') c = getchar(); while(c >= '0' && c <= '9') x = x * 10 + c - 48,c = getchar(); return x; } mapQ[_]; int c[_],par[_],M[_]; struct Edge{ int nxt,point; }edge[_<<1]; int head[_],tot; void add_edge(int u,int v){ edge[++tot].nxt = head[u]; edge[tot].point = v; head[u] = tot; } int n; void dfs(int u,int fa) { par[u] = fa; M[u] = max(M[fa],u); for(int i = head[u]; i; i = edge[i].nxt){ int v = edge[i].point; if(v ^ fa){ dfs(v,u); } } } int solveQ(int u,int val){ if(Q[u].count(val)) return Q[u][val]; int ans = 1; for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; if(v ^ par[u] && v <= val) ans += solveQ(v,val); } Q[u][val] = ans; return ans; } void solvec(int u){ if(u == 1) c[u] = 0; else if(M[par[par[u]]] < u) c[u] = c[par[u]] + Q[u][M[par[u]]]; else c[u] = c[par[u]] + Q[u][M[par[u]]] - Q[u][M[par[par[u]]]]; for(int i = head[u]; i ; i = edge[i].nxt){ int v = edge[i].point; if(v ^ par[u]) solvec(v); } } int main(){ n = read(); for(int i = 1; i < n; ++i){ int u = read(),v = read(); add_edge(u,v);add_edge(v,u); } dfs(1,0); for(int r = 2; r <= n; ++r){ solveQ(r,M[par[r]]); int G = par[par[r]] ; if(G) solveQ(r,M[G]); } solvec(1); for(int i = 2; i <= n; ++i) cout<<c[i]<<' '; return 0; }