3611: [Heoi2014]大工程
Time Limit: 60 Sec  Memory Limit: 512 MB  Submit: 1945  Solved: 811
[Submit][Status][Discuss]
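Description
给定一棵 n 个点的树，每条边长度为 1。现有 q 个计划：每个计划选中 k 个点，并在这些点两两之间各新建一条通道，在点 a、b 之间建通道的代价为树上 a 到 b 的最短路径长度。对每个计划，求这些新通道的代价和、最小代价以及最大代价。
Input
第一行 n 表示点数。
接下来 n-1 行，每行两个整数 a b，表示 a 与 b 之间有一条边，点从 1 开始编号。
接下来一行 q 表示计划数。每个计划占两行：第一行 k 表示选中的点数，第二行 k 个整数表示选中的点。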
Output
输出 q 行,每行三个数分别表示代价和,最小代价,最大代价。
Sample Input
10
2 1
3 2
4 1
5 2
6 4
7 5
8 6
9 7
10 9
5
2
5 4
2
10 4
2
5 2
2
6 1
2
6 1
Sample Output
3 3 3
6 6 6
1 1 1
2 2 2
2 2 2
HINT
n<=1000000
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

// BZOJ 3611 [HEOI2014] 大工程.
// For every query set of k marked vertices on an unweighted tree, print
//   (sum of pairwise distances, minimum pairwise distance, maximum pairwise distance).
// Each query builds a virtual (auxiliary) tree over the marked vertices and
// their pairwise LCAs, then runs a DP over it.
//
// Both the preprocessing DFS and the per-query DP are iterative: with
// n up to 1e6 the tree (and a virtual tree built on it) can be a path of
// depth ~1e6, which overflows a typical default call stack when recursive.

typedef long long ll;

const int maxn = 1000010;
const int MAXLOG = 22;                 // binary-lifting levels; 2^21 > maxn
const ll inf = 100000000000000000LL;   // 1e17: "no marked vertex below" sentinel

int n, Q, k, Time, dfs_clock, top, tot = 1;
// Edge pool: first holds the input tree, then is reused for each virtual tree.
int head[maxn], to[maxn * 2], nextt[maxn * 2];
int deep[maxn], pos[maxn], fa[maxn][MAXLOG];
// sta[0] is never written and stays 0. Together with deep[0] == 0 it acts as
// a below-the-root sentinel for the virtual-tree builder in solve().
int sta[maxn], a[maxn], ord[maxn], sizee[maxn], flag[maxn];
ll maxx[maxn], minn[maxn], ans, ans1, ans2;

// Append directed edge x -> y to the adjacency lists.
void add(int x, int y) {
    to[tot] = y;
    nextt[tot] = head[x];
    head[x] = tot++;
}

// Add virtual-tree edge x -> y. The builder may ask to link a vertex to
// itself (when it is its own LCA); such self-loops are dropped.
void add2(int x, int y) {
    if (x == y) return;
    add(x, y);
}

// Lowest common ancestor of x and y via binary lifting.
// Relies on deep[0] == 0 so jumps past the root (to vertex 0) are never taken.
int lca(int x, int y) {
    if (deep[x] < deep[y]) std::swap(x, y);
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (deep[fa[x][i]] >= deep[y]) x = fa[x][i];
    if (x == y) return x;
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) {
            x = fa[x][i];
            y = fa[y][i];
        }
    return fa[x][0];
}

// Iterative DFS from u (whose parent is faa). Sets fa[][0], deep[] and the
// preorder index pos[]. Popping a vertex and then pushing its children keeps
// every subtree contiguous in pos, which the virtual-tree builder requires.
// Uses sta[1..] as the explicit stack and leaves sta[0] untouched.
void dfs(int u, int faa) {
    fa[u][0] = faa;
    deep[u] = deep[faa] + 1;
    int sp = 0;
    sta[++sp] = u;
    while (sp) {
        int cur = sta[sp--];
        pos[cur] = ++dfs_clock;
        for (int i = head[cur]; i; i = nextt[i]) {
            int v = to[i];
            if (v == fa[cur][0]) continue;
            fa[v][0] = cur;
            deep[v] = deep[cur] + 1;
            sta[++sp] = v;
        }
    }
}

// Order vertices by DFS preorder index.
bool cmp(int x, int y) { return pos[x] < pos[y]; }

// DP over the virtual tree rooted at u. Vertices are processed in reverse
// BFS order, so every child is finished before its parent. For each vertex:
//   sizee = number of marked vertices in its virtual subtree,
//   minn / maxx = min / max distance down to a marked vertex (inf / -inf if none).
// Accumulates ans (each edge of length w contributes w * size * (k - size)),
// ans1 (min pairwise distance) and ans2 (max pairwise distance).
// Resets head[] of every visited vertex so the next query starts clean.
void dp(int u) {
    int cnt = 0;
    ord[cnt++] = u;
    for (int idx = 0; idx < cnt; idx++)
        for (int i = head[ord[idx]]; i; i = nextt[i]) ord[cnt++] = to[i];

    for (int idx = cnt - 1; idx >= 0; idx--) {
        int cur = ord[idx];
        if (flag[cur] == Time) {
            sizee[cur] = 1;
            maxx[cur] = minn[cur] = 0;
        } else {
            sizee[cur] = 0;
            maxx[cur] = -inf;
            minn[cur] = inf;
        }
        for (int i = head[cur]; i; i = nextt[i]) {
            int v = to[i];
            ll w = deep[v] - deep[cur];  // compressed path length of this edge
            sizee[cur] += sizee[v];
            // Pair a marked vertex below v with one seen earlier at/below cur.
            ans1 = std::min(ans1, minn[v] + minn[cur] + w);
            ans2 = std::max(ans2, maxx[v] + maxx[cur] + w);
            minn[cur] = std::min(minn[cur], minn[v] + w);
            maxx[cur] = std::max(maxx[cur], maxx[v] + w);
            ans += w * sizee[v] * (k - sizee[v]);
        }
        head[cur] = 0;
    }
}

// Answer one query: read the k marked vertices, build their virtual tree
// rooted at 1 with the classic monotone-stack method, run the DP and print.
// NOTE(review): with fewer than two marked vertices, ans1/ans2 keep their
// inf/-inf sentinels; presumably the input guarantees k >= 2 — confirm.
void solve() {
    ++Time;
    ans = 0;
    ans1 = inf;
    ans2 = -inf;
    std::scanf("%d", &k);
    for (int i = 1; i <= k; i++) {
        std::scanf("%d", &a[i]);
        flag[a[i]] = Time;
    }
    std::sort(a + 1, a + 1 + k, cmp);

    top = 0;
    tot = 1;
    sta[++top] = 1;
    for (int i = 1; i <= k; i++) {
        int l = lca(a[i], sta[top]);
        // Pop every stacked vertex deeper than l, linking it to its parent on
        // the stack (or to l for the last one), then make sure l is stacked.
        while (true) {
            if (deep[sta[top - 1]] <= deep[l]) {
                add2(l, sta[top]);
                top--;
                if (sta[top] != l) sta[++top] = l;
                break;
            }
            add2(sta[top - 1], sta[top]);
            top--;
        }
        if (sta[top] != a[i]) sta[++top] = a[i];
    }
    // Link whatever remains on the stack as a chain.
    top--;
    while (top) {
        add2(sta[top], sta[top + 1]);
        top--;
    }

    dp(1);
    std::printf("%lld %lld %lld\n", ans, ans1, ans2);
}

int main() {
    std::scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int x, y;
        std::scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    dfs(1, 0);
    for (int j = 1; j < MAXLOG; j++)
        for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];

    // The input tree is no longer needed; reuse the edge pool for virtual trees.
    std::memset(head, 0, sizeof(head));
    tot = 1;

    std::scanf("%d", &Q);
    while (Q--) solve();
    return 0;
}