1.​​题目链接​​。给定一棵有n个点的树,询问树上距离为k的点对是否存在。树上距离定义为两点之间的路径和。(可以知道,树上两点路径是唯一的,因为不存在环)。点分治的模板题,基本思想就是通过选取一个点,把树分割成不同的子树,然后在子树里解决,采用分治思想,之所以叫做点分治,因为是以点作为标准来实现分治。本题首先找到树的重心,这个一遍dfs即可找到,然后从这里开始对这个点的每颗子树分治,分治的过程中,rem[i]代表当前点到当前选择的根节点的距离,然后在处理第i个子树时,我们需要知道前i-1个子树的信息,用一个judge数组保存一下,judge[i]表示前i-1棵子树里面距离根节点距离为i是否存在,然后枚举当前子树所有的点,通过judge数组匹配。完成之后,分治下去,直到遍历所有的点。时间复杂度N*logN

#include<bits/stdc++.h>
using namespace std;
const int inf = 10000000;
const int maxn = 100010;
int n, m;
struct node { int v, dis, nxt; }E[maxn << 1];
int tot, head[maxn];
int maxp[maxn], siz[maxn], dis[maxn], rem[maxn];
int vis[maxn], test[inf], judge[inf], q[maxn];
int query[1010];
int sum, rt;
int ans;

void add(int u, int v, int dis)
{
E[++tot].nxt = head[u];
E[tot].v = v;
E[tot].dis = dis;
head[u] = tot;
}

void getrt(int u, int pa)
{
siz[u] = 1;
maxp[u] = 0;
for (int i = head[u]; i; i = E[i].nxt)
{
int v = E[i].v;
if (v == pa || vis[v])continue;
getrt(v, u);
siz[u] += siz[v];
maxp[u] = max(maxp[u], siz[v]);
}
maxp[u] = max(maxp[u], sum - siz[u]);
if (maxp[u] < maxp[rt])rt = u;
}



void getdis(int u, int fa)
{
rem[++rem[0]] = dis[u];
for (int i = head[u]; i; i = E[i].nxt)
{
int v = E[i].v;
if (v == fa || vis[v])continue;
dis[v] = dis[u] + E[i].dis;
getdis(v, u);
}
}


void calc(int u)
{
int p = 0;
for (int i = head[u]; i; i = E[i].nxt)
{
int v = E[i].v;
if (vis[v])continue;
rem[0] = 0;
dis[v] = E[i].dis;
getdis(v, u);

for (int j = rem[0]; j; --j)
for (int k = 1; k <= m; ++k)
if (query[k] >= rem[j])
test[k] |= judge[query[k] - rem[j]];

for (int j = rem[0]; j; --j)
q[++p] = rem[j], judge[rem[j]] = 1;
}
for (int i = 1; i <= p; ++i)
judge[q[i]] = 0;
}

void solve(int u)
{
vis[u] = judge[0] = 1;
calc(u);
for (int i = head[u]; i; i = E[i].nxt)//对每个子树进行分治
{
int v = E[i].v;
if (vis[v])continue;
sum = siz[v];
maxp[rt = 0] = inf;
getrt(v, 0);
solve(rt);//在子树中找重心并递归处理
}
}

int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i < n; ++i)
{
int u, v, dis;
scanf("%d%d%d", &u, &v, &dis);
add(u, v, dis);
add(v, u, dis);
}
for (int i = 1; i <= m; ++i)
scanf("%d", &query[i]);

maxp[rt] = sum = n;
getrt(1, 0);
solve(rt);

for (int i = 1; i <= m; ++i)
{
if (test[i]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}