题目链接:传送门

做法与这道题一样(点分治):
先统计距离 <= k 的点对数,再统计距离 < k 的点对数,
两者相减就是距离恰好等于 k 的点对数。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <complex>
#include <algorithm>
#include <climits>
#include <queue>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
// Capacity of the per-vertex arrays; must exceed the largest vertex count n.
// NOTE(review): the original values of these two macros were lost from the
// source; 100010 is an assumed limit -- TODO confirm against the problem bounds.
#define A 100010
// Size in bytes of the fread() buffer used by GETCHAR(); parenthesised so the
// macro stays safe inside larger expressions.
#define B (1 << 20)

using namespace std;
typedef long long ll;
// One directed edge of the adjacency list ("forward star" layout):
// edges out of a vertex form a singly linked list threaded through e[].
struct node {
    int next;  // id of the next edge out of the same vertex; 0 ends the list
    int to;    // head vertex of this edge
    int w;     // edge weight
} e[A << 1];   // main() stores every undirected edge twice (a->b and b->a)
int head[A << 1];  // head[v]: id of the first edge out of v, 0 if none
int num;           // number of edges stored so far (ids start at 1)
// Pushes the directed edge fr -> to (weight w) onto the front of fr's
// adjacency list. Edge ids start at 1 so that id 0 can terminate a list.
void add(int fr, int to, int w) {
    const int id = ++num;
    e[id].to = to;
    e[id].w = w;
    e[id].next = head[fr];
    head[fr] = id;
}
int n;      // number of vertices
int k;      // target distance
int G;      // centroid found by the latest getG() call (0 = none yet)
int a;      // scratch: first endpoint of the edge being read
int b;      // scratch: second endpoint of the edge being read
int c;      // scratch: weight of the edge being read
int tot;    // size of the component currently being searched for a centroid
ll ans;     // running count of vertex pairs at distance exactly k
int siz[A]; // siz[v]: subtree size of v in the latest getG() traversal
int mx[A];  // mx[v]: largest remaining part if v were removed; mx[0] is a sentinel
int dep[A]; // dep[0]: number of collected distances; dep[1..dep[0]]: the distances
int dis[A]; // NOTE(review): not referenced anywhere in the visible code
int vis[A]; // vis[v] = 1 once v has been used as a centroid (removed from the tree)
int f[A];   // f[v]: distance assigned to v by dfs() during the current calc()
// Searches the component containing `fr` (whose size must be in `tot`) for its
// centroid. Fills siz[] with subtree sizes of the DFS started at the first
// call, mx[v] with the largest part left if v were removed, and keeps in G the
// node with the smallest mx. Expects G == 0 and mx[0] == tot on entry: since
// every real node has mx[v] <= tot - 1, some real node always beats the sentinel.
void getG(int fr, int fa) {
    siz[fr] = 1;
    mx[fr] = 0;  // largest child subtree so far; a leaf has none (was 1, which
                 // left one-node components with no centroid, i.e. G == 0)
    for (int i = head[fr]; i; i = e[i].next) {
        int ca = e[i].to;
        if (ca == fa or vis[ca]) continue;
        getG(ca, fr);
        siz[fr] += siz[ca];
        mx[fr] = max(mx[fr], siz[ca]);
    }
    // The nodes outside fr's subtree (its "parent side") also become a separate
    // part once fr is removed.
    mx[fr] = max(mx[fr], tot - siz[fr]);
    G = mx[fr] < mx[G] ? fr : G;
}
// Appends f[] for every node reachable from `fr` (never stepping back to `fa`
// or into a removed centroid) to dep[1..], using dep[0] as the fill counter.
// f[fr] must already hold fr's own distance; children get f[parent] + weight.
void dfs(int fr, int fa) {
    const int slot = ++dep[0];
    dep[slot] = f[fr];
    for (int i = head[fr]; i != 0; i = e[i].next) {
        const int nxt = e[i].to;
        if (nxt != fa && !vis[nxt]) {
            f[nxt] = f[fr] + e[i].w;
            dfs(nxt, fr);
        }
    }
}
// Counts index pairs l < r in the sorted list dep[1..dep[0]] with
// dep[l] + dep[r] <= limit, using a two-pointer sweep.
static long long countPairsWithin(int limit) {
    long long pairs = 0;
    for (int l = 1, r = dep[0]; l < r;)
        if (dep[l] + dep[r] <= limit) pairs += r - l, l++;  // (l, l+1..r) all fit
        else r--;
    return pairs;
}

// Collects the distances of all nodes reachable from `fr` (not crossing removed
// centroids), with fr itself at distance `len`. Returns the number of unordered
// node pairs whose two distances sum to exactly k, plus sum - sum2.
// The count is taken as (#pairs <= k) - (#pairs < k). It is accumulated in
// long long because it can reach n(n-1)/2, which overflows int for large n.
// sum/sum2 are kept only for call compatibility; the file always passes 0.
long long calc(int fr, int len, long long sum = 0, long long sum2 = 0) {
    dep[0] = 0; f[fr] = len; dfs(fr, 0);
    sort(dep + 1, dep + 1 + dep[0]);
    sum += countPairsWithin(k);       // pairs with distance <= k
    sum2 += countPairsWithin(k - 1);  // pairs with distance < k (integer distances)
    return sum - sum2;
}
void divide(int fr) {
ans += calc(fr, 0); vis[fr] = 1;
for (int i = head[fr]; i; i = e[i].next) {
int ca = e[i].to;
if (vis[ca]) continue;
ans -= calc(ca, e[i].w); //经过重心下面的点就统计答案,要减掉
tot = mx[0] = siz[ca]; //新树的大小,即子问题
G = 0; getG(ca, 0); divide(G);
}
}
// Buffered byte reader over stdin: refills a B-byte buffer with fread() when it
// is empty. Returns the next byte as an unsigned-char value (0..255), or EOF
// once input is exhausted. The return type is int so that EOF stays distinct
// from byte 0xFF and the result is always a valid argument for isdigit().
inline int GETCHAR() {
    static char buf[B], *p1 = buf, *p2 = buf;
    if (p1 == p2) {
        p1 = buf;
        p2 = buf + fread(buf, 1, B, stdin);
        if (p1 == p2) return EOF;
    }
    return static_cast<unsigned char>(*p1++);
}
// Reads the next run of decimal digits from GETCHAR() into x, skipping any
// non-digit bytes before it. Signs are skipped like other non-digits, so "-5"
// reads as 5. If input ends before a digit is seen, x is left as 0 instead of
// looping forever on EOF.
template<class T> void read(T &x) {
    x = 0;
    int ch = GETCHAR();
    while (ch != EOF && !isdigit(ch)) ch = GETCHAR();
    while (ch != EOF && isdigit(ch)) x = x * 10 + (ch - '0'), ch = GETCHAR();
}

int main(int argc, char const *argv[]) {
scanf("%d%d", &n, &k);
for (int i = 1; i < n; i++) {
read(a); read(b); c = 1;
add(a, b, c); add(b, a, c);
}
tot = mx[0] = n; getG(1, 0);
divide(G); printf("%d\n", ans);
return 0;
}