https://gmoj.net/senior/#contest/show/2989/2

思考什么时候先手会赢。

一开始双方都不会希望走到直径的端点上,因为那样对方就可以走直径而使自己输掉。

删掉直径的端点,考虑剩下的树的子问题。

如果又走到端点去了,对面就走到另外一个端点,那我就走到下一层的直径端点去了。

所以大家一直都不想走到直径端点。

一直删,如果最后只剩1一个点,说明先手必败,否则先手必胜。

如果是一条链,就是链的两边的长度不等先手就必胜。

如果是一棵树,考虑随便找一条直径,每次删去它的两个端点。

1.这条直径不经过1,1会在中间被删掉,先手必胜;

2.这条直径经过1,则只有1是直径的中点时先手必败。

然后变成了一个dp:
\(f[i][j]\)表示\(i\)为根的子树里,选了含根联通块,到\(i\)的距离最大的距离是\(j\)

设\(md[x]\)表示\(x\)子树里的点到\(x\)的最大距离。

显然\(0<=j<=md[x]\)。

对树长链剖分,每次把一个短的子树合并过来。

假设长链的长度是p,合并来的链的长度是q,当前点是x。

对\(f[x][0-q+1]\)暴力求个前缀和来转移。

\(f[x][>q+1]\)是整体乘上一个数,再用一个数组来打lazytag。

最后合并时还需要一些细节,不过比较简单,前后缀扫一下即可。

pty教的实现\(f\)时用指针会很方便,见代码。

#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;

const int mo = 998244353;

const int N = 1e6 + 5;

int n, x, y;
int fi[N], nt[N * 2], to[N * 2], tot;

void link(int x, int y) {
	nt[++ tot] = fi[x], to[tot] = y, fi[x] = tot;
}

int t[N], t0;

void Init() {
	scanf("%d", &n);
	fo(i, 1, n - 1) {
		scanf("%d %d", &x, &y);
		link(x, y); link(y, x);
	}
}

int md[N], fa[N], son[N];

void bfs() {
	t[t0 = 1] = 1;
	for(int i = 1; i <= t0; i ++) {
		int x = t[i];
		for(int j = fi[x]; j; j = nt[j]) if(to[j] != fa[x]) {
			fa[to[j]] = x;
			t[++ t0] = to[j];
		}
	}
	fd(i, t0, 1) {
		int x = t[i];
		for(int j = fi[x]; j; j = nt[j]) if(to[j] != fa[x]) {
			int y = to[j];
			md[x] = max(md[x], md[y] + 1);
			if(md[y] > md[son[x]]) son[x] = y;
		}
	}
}

ll fv[N * 2], *f[N], gv[N * 2], *g[N];

int d[N], d0, us;

void build() {
	fo(x, 1, n) if(son[fa[x]] != x) {
		d[d0 = 1] = x;
		for(int i = 1; i <= d0; i ++)
			if(son[d[i]]) d[++ d0] = son[d[i]];
		fo(i, 1, d0) f[d[i]] = fv + (us + i), g[d[i]] = gv + (us + i);
		us += d0 + 1;
	}
}

void xc(int x, int d) {
	fo(i, 0, d) {
		g[x][i + 1] = g[x][i + 1] * g[x][i] % mo;
		f[x][i] = f[x][i] * g[x][i] % mo;
		g[x][i] = 1;
	}
}

ll sa[N], sb[N];

void dp() {
	fd(i2, t0, 2) {
		int x = t[i2];
		f[x][0] = 1;
		for(int ii = fi[x]; ii; ii = nt[ii]) if(to[ii] != fa[x] && to[ii] != son[x]) {
			int y = to[ii];
			xc(x, md[y] + 1); xc(y, md[y]);
			sa[0] = sb[0] = 1;
			fo(i, 1, md[y] + 1) {
				sa[i] = (sa[i - 1] + f[x][i]) % mo;
				sb[i] = (sb[i - 1] + f[y][i - 1]) % mo;
			}
			f[x][0] = 1;
			fo(i, 1, md[y] + 1) f[x][i] = (sa[i] * sb[i] - sa[i - 1] * sb[i - 1] % mo + mo) % mo;
			g[x][md[y] + 2] = g[x][md[y] + 2] * sb[md[y] + 1] % mo;
		}
	}
}

ll p[N], q[N];
int ky[N];

int main() {
	freopen("tree.in", "r", stdin);
	freopen("tree.out", "w", stdout);
	Init();
	md[0] = -1;
	bfs();
	build();
	fo(i, 0, 2 * n) gv[i] = 1;
	dp();
	ll ans = 0;
	d0 = 0;
	for(int i = fi[1]; i; i = nt[i]) {
		int y = to[i];
		d[++ d0] = y;
		xc(y, md[y]);
	}
	ll s1 = 1;
	fo(w, 0, md[1] - 1) {
		p[0] = 1; q[d0 + 1] = 1;
		fo(i, 1, d0) {
			int x = d[i];
			p[i] = p[i - 1] * (w == 0 ? 1 : f[x][w - 1]) % mo;
		}
		fd(i, d0, 1) {
			int x = d[i];
			q[i] = q[i + 1] * (w == 0 ? 1 : f[x][w - 1]) % mo;
			ans = (ans + f[x][w] * p[i - 1] % mo * q[i + 1] % mo * s1) % mo;
			
			f[x][w] = ((w ? f[x][w - 1] : 1) + f[x][w]) % mo;
			ky[i] = md[x] > w;
		}
		int D = d0; d0 = 0;
		fo(i, 1, D) if(ky[i])
			d[++ d0] = d[i]; else s1 = s1 * f[d[i]][md[d[i]]] % mo;
	}
	ans = (ans % mo + mo) % mo;
	pp("%lld\n", ans);
}