题目
Description
有一个挖宝游戏,它在一棵 n 个点的树上进行,宝藏埋在某个未知的点 ????。每次挖掘一个点 u,玩家得到的反馈信息是一个数值 d,
表示 u 号点到 ???? 号点简单路径上的边数。这个游戏会进行 q 次,每次游戏藏宝的位置不一定相同。
你作为一名优秀的 ????????er,对自己无比自信。你希望用最少的挖掘次数来找出宝藏。于是你挑了两个不同的点 a,b 进行挖掘,并得
到了反馈信息,分别为 。接下来的第三次挖掘中,你想要直接奔着一个可能的 ???? 进行挖掘。由于树太大了,凭借人眼无法找
出 ???? 的确切位置,你便转向了电脑,开始写一个程序,帮助你解决这个问题。
Input
第一行输入两个正整数 n,q,表示树的点数和游戏次数。
接下来 n − 1 行,每行输入两个正整数 u,v,描述一条树边。保证输入的是一棵树。
接下来 q 行,每行输入四个正整数 表示一次游戏进行两次挖掘得到的反馈信息。
Output
输出 q行。每行包含一个整数表示一个满足条件的藏宝地点,或者输出 −1 表示无解。当有多个可能的藏宝地点时,输出任意一个即可。
Sample Input
【样例输入】
5 3
1 2
2 3
3 4
3 5
2 1 4 1
2 2 4 2
1 1 2 1
Sample Output
【样例输出】
3
5
-1
Data Constraint
思路
首先要判断这个点是否存在,这很容易判
然后我们分四种情况讨论
- u 在 a 子树中。需要保证深度足够
- u 在 b 子树中。同1
- u 在 路径(a,b)上的某个点的子树内,这相当于走了一段然后下去
- u 在 路径(a,b)外面的子树,相当于先走到lca再下去
代码
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+77;
int a[N],b[N],dep[N],n,q,cnt,nxt[N<<1],to[N<<1],ls[N],c[N],d[N],pos[N<<1],num,logn[N<<1],st[N<<1][21],g[N];
int dfn[N],f[N][20],ans,s3[N],t3[N];
void add(int x,int y)
{
nxt[++num]=ls[x]; ls[x]=num; to[num]=y;
nxt[++num]=ls[y]; ls[y]=num; to[num]=x;
}
void dfs1(int u,int pa)
{
dfn[u]=++cnt;
dep[u]=dep[pa]+1;
pos[cnt]=u; a[u]=b[u]=c[u]=d[u]=u;
for(int i=0; i<19; i++) f[u][i+1]=f[f[u][i]][i];
for(int i=ls[u]; i; i=nxt[i])
{
int v=to[i];
if(v==pa) continue;
f[v][0]=u;
dfs1(v,u);
pos[++cnt]=u;
if(dep[c[v]]>dep[c[u]])
{
t3[u]=d[u];
d[u]=c[u];
c[u]=c[v];
s3[u]=b[u];
b[u]=a[u];
a[u]=v;
}
else if(dep[c[v]]>dep[d[u]])
{
t3[u]=d[u];
d[u]=c[v];
s3[u]=b[u];
b[u]=v;
}
else if(dep[c[v]]>dep[t3[u]])
{
t3[u]=c[v];
s3[u]=v;
}
}
}
void init()
{
logn[0]=-1;
for(int i=1; i<=cnt; i++)
{
logn[i]=logn[i>>1]+1;
st[i][0]=pos[i];
}
for(int j=1; (1<<j)<=cnt; j++) for(int i=1; i+(1<<j)-1<=cnt; i++)
{
int u=st[i][j-1],v=st[i+(1<<j-1)][j-1];
if(dep[u]<dep[v]) st[i][j]=u;
else st[i][j]=v;
}
}
int lca(int l,int r)
{
l=dfn[l]; r=dfn[r];
if(l>r) swap(l,r);
int k=logn[r-l+1],u=st[l][k],v=st[r-(1<<k)+1][k];
return dep[u]<dep[v]?u:v;
}
int dist(int x,int y)
{
int z=lca(x,y);
return dep[x]+dep[y]-(dep[z]<<1);
}
int jump(int x,int d)
{
for(int i=19; i>=0; i--) if(d&(1<<i)) x=f[x][i];
return x;
}
int calc(int x,int y,int d)
{
int z=lca(x,y),dis=dep[x]+dep[y]-(dep[z]<<1);
if(d<=dep[x]-dep[z]) return jump(x,d);
else return jump(y,dis-d);
}
void dfs2(int u,int pa)
{
for(int i=ls[u]; i; i=nxt[i])
{
int v=to[i];
if(v==pa) continue;
g[v]=g[u];
if(a[u]==v)
{
if(dist(d[u],v)>dist(g[v],v)) g[v]=d[u];
}
else
{
if(dist(c[u],v)>dist(g[v],v)) g[v]=c[u];
}
dfs2(v,u);
}
}
int solve12(int u,int v,int du,int dv,int z)
{
int dis=dep[u]+dep[v]-(dep[z] << 1);
if(dv==du+dis)
{
if(z==u)
{
int fv=calc(u,v,1);
if(fv!=a[u])
{
if(dep[c[u]]-dep[u]>=du) return calc(u,c[u],du);
}
else if(dep[d[u]]-dep[u]>=du) return calc(u,d[u],du);
}
else if(dep[c[u]]-dep[u]>=du) return calc(u,c[u],du);
}
if(du==dv+dis)
{
if(z==v)
{
int fu=calc(v,u,1);
if(fu!=a[v])
{
if(dep[c[v]]-dep[v]>=dv) return calc(v,c[v],dv);
}
else if(dep[d[v]]-dep[v]>=dv) return calc(v,d[v],dv);
}
else if(dep[c[v]]-dep[v]>=dv) return calc(v,c[v],dv);
}
return -1;
}
int solve3(int u,int v,int du,int dv,int z)
{
int dis=dep[u]+dep[v]-(dep[z]<<1);
if(du+dv>=dis)
{
int del = du - dv;
if((dis+del)&1) return -1;
int a1=dis+del>>1,b1=dis-a1,y=calc(u,v,a1),tot=du-a1;
if(a1<0||b1<0) return -1;
if(y==u||y==v) return -1;
if(y==z)
{
int fu=jump(u,dep[u]-dep[z]-1),fv=jump(v,dep[v]-dep[z]-1);
if(a[z] != fu && a[z] != fv)
{
if(dep[c[z]]-dep[z]>=tot) return calc(z,c[z],tot);
}
else if(b[z]!=fu&&b[z]!=fv)
{
if(dep[d[z]]-dep[z]>=tot) return calc(z,d[z],tot);
}
else
{
if(dep[t3[z]]-dep[z]>=tot) return calc(z,t3[z],tot);
}
}
else if(a1<=dep[u]-dep[z])
{
int ys=calc(u,v,a1-1);
if(ys==a[y])
{
if(dep[d[y]]-dep[y]>=tot) return calc(y,d[y],tot);
}
else
{
if(dep[c[y]]-dep[y]>=tot) return calc(y,c[y],tot);
}
}
else
{
int ys=calc(u,v,a1+1);
if(ys==a[y])
{
if(dep[d[y]]-dep[y]>=tot) return calc(y,d[y],tot);
}
else
{
if(dep[c[y]]-dep[y]>=tot) return calc(y,c[y],tot);
}
}
}
return -1;
}
int solve4(int u,int v,int du,int dv,int z)
{
int dis = dep[u] + dep[v] - (dep[z] << 1);
if(du - dv == dep[u] - dep[v] && dep[u] - dep[z] <= du && dep[v] - dep[z] <= dv)
{
int dc = du - (dep[u] - dep[z]);
if(dist(z,g[z]) >= dc) return calc(z,g[z],dc);
}
return -1;
}
int main()
{
freopen("hunting.in","r",stdin); freopen("hunting.out","w",stdout);
scanf("%d%d",&n,&q);
for(int i=1,x,y; i<n; i++) scanf("%d%d",&x,&y),add(x,y);
dfs1(1,0);
init();
g[1]=1;
dfs2(1,0);
while(q--)
{
int u,du,v,dv;
scanf("%d%d%d%d",&u,&du,&v,&dv);
if(u==v&&du!=dv)
{
printf("-1\n"); continue;
}
ans=-1;
if(u==v)
{
if(dist(u,g[u])>=du) ans=calc(u,g[u],du);
else if(dist(u,c[u])>=du) ans=calc(u,c[u],du);
printf("%d\n",ans);
continue;
}
int z=lca(u,v);
ans=solve12(u,v,du,dv,z);
if(ans==-1) ans=solve3(u,v,du,dv,z);
if(ans==-1) ans=solve4(u,v,du,dv,z);
printf("%d\n",ans);
}
}