题目:https://www.luogu.org/problemnew/show/P3398
树链剖分一下,路径就变成线段树上的几个区间;
两条路径相交就是线段树上有区间相交,所以在相应位置打个标记,查询有无标记即可;
一开始是打1的标记,查询后就减去,查询 sum 是否为 0 即可;
然而这样写却全 WA 了...悲痛欲绝去看了 TJ ,模仿了其写法,回头再看发现是忘记写 pushdown ,而且 -1 的地方写成 0 了囧...
改掉就 A 了,这个做法完全没问题嘛!
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; int const maxn=1e5+5; int n,q,hd[maxn],ct,sum[maxn<<2],tim,fa[maxn],lzy[maxn<<2]; int dep[maxn],dfn[maxn],top[maxn],son[maxn],siz[maxn],id[maxn]; struct N{ int to,nxt; N(int t=0,int n=0):to(t),nxt(n) {} }ed[maxn<<1]; void add(int x,int y){ed[++ct]=N(y,hd[x]); hd[x]=ct;} void dfs(int x,int f) { dep[x]=dep[f]+1; siz[x]=1; fa[x]=f; for(int i=hd[x],u;i;i=ed[i].nxt) { if((u=ed[i].to)==f)continue; dfs(u,x); siz[x]+=siz[u]; if(siz[u]>siz[son[x]])son[x]=u; } } void dfs2(int x) { dfn[x]=++tim; id[tim]=x;// if(son[x])top[son[x]]=top[x],dfs2(son[x]); for(int i=hd[x],u;i;i=ed[i].nxt) { if((u=ed[i].to)==fa[x]||u==son[x])continue; top[u]=u; dfs2(u); } } void pushup(int x){sum[x]=sum[x<<1]+sum[x<<1|1];} void pushdown(int x,int l,int r) { if(!lzy[x])return; int ls=(x<<1),rs=(x<<1|1),mid=((l+r)>>1); lzy[ls]+=lzy[x]; lzy[rs]+=lzy[x]; sum[ls]+=lzy[x]*(mid-l+1); sum[rs]+=lzy[x]*(r-mid); lzy[x]=0; } void add(int x,int l,int r,int L,int R,int val) { if(l>=L&&r<=R){sum[x]+=val*(r-l+1); lzy[x]+=val; return;} pushdown(x,l,r);// int mid=((l+r)>>1); if(mid>=L)add(x<<1,l,mid,L,R,val); if(mid<R)add(x<<1|1,mid+1,r,L,R,val); pushup(x); } int ask(int x,int l,int r,int L,int R) { if(l>=L&&r<=R)return sum[x]; pushdown(x,l,r);// int ret=0,mid=((l+r)>>1); if(mid>=L)ret+=ask(x<<1,l,mid,L,R); if(mid<R)ret+=ask(x<<1|1,mid+1,r,L,R); return ret; } void update(int x,int y,int val) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); add(1,1,n,dfn[top[x]],dfn[x],val); x=fa[top[x]]; } if(dep[x]<dep[y])swap(x,y); add(1,1,n,dfn[y],dfn[x],val); } bool query(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); if(ask(1,1,n,dfn[top[x]],dfn[x])) return 1; x=fa[top[x]]; } if(dep[x]<dep[y])swap(x,y); return ask(1,1,n,dfn[y],dfn[x]); } int main() { scanf("%d%d",&n,&q); for(int i=1,x,y;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,0); top[1]=1; dfs2(1); for(int i=1,a,b,c,d;i<=q;i++) { scanf("%d%d%d%d",&a,&b,&c,&d); update(a,b,1); if(query(c,d))printf("Y\n"); else printf("N\n"); update(a,b,-1);//!0 } return 0; }
还有 TJ 写法,就是不删除了,打个 int 类型的标记,所以查询 max 即可。
代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> using namespace std; int const maxn=1e5+5; int n,q,hd[maxn],ct,mx[maxn<<2],tim,fa[maxn],lzy[maxn<<2],tot; int dep[maxn],dfn[maxn],top[maxn],son[maxn],siz[maxn],id[maxn]; struct N{ int to,nxt; N(int t=0,int n=0):to(t),nxt(n) {} }ed[maxn<<1]; void add(int x,int y){ed[++ct]=N(y,hd[x]); hd[x]=ct;} void dfs(int x,int f) { dep[x]=dep[f]+1; siz[x]=1; fa[x]=f; for(int i=hd[x],u;i;i=ed[i].nxt) { if((u=ed[i].to)==f)continue; dfs(u,x); siz[x]+=siz[u]; if(siz[u]>siz[son[x]])son[x]=u; } } void dfs2(int x) { dfn[x]=++tim; id[tim]=x;// if(son[x])top[son[x]]=top[x],dfs2(son[x]); for(int i=hd[x],u;i;i=ed[i].nxt) { if((u=ed[i].to)==fa[x]||u==son[x])continue; top[u]=u; dfs2(u); } } void pushup(int x){mx[x]=max(mx[x<<1],mx[x<<1|1]);} void pushdown(int x) { if(!lzy[x])return; lzy[x<<1]=lzy[x<<1|1]=lzy[x]; mx[x<<1]=mx[x<<1|1]=lzy[x]; lzy[x]=0; } void add(int x,int l,int r,int L,int R,int val) { if(l>=L&&r<=R){mx[x]=lzy[x]=val; return;} pushdown(x); int mid=((l+r)>>1); if(mid>=L)add(x<<1,l,mid,L,R,val); if(mid<R)add(x<<1|1,mid+1,r,L,R,val); pushup(x); } int ask(int x,int l,int r,int L,int R) { if(l>=L&&r<=R)return mx[x]; pushdown(x); int ret=0,mid=((l+r)>>1); if(mid>=L)ret=max(ret,ask(x<<1,l,mid,L,R)); if(mid<R)ret=max(ret,ask(x<<1|1,mid+1,r,L,R)); return ret; } void update(int x,int y,int val) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); add(1,1,n,dfn[top[x]],dfn[x],val); x=fa[top[x]]; } if(dep[x]<dep[y])swap(x,y); add(1,1,n,dfn[y],dfn[x],val); } bool query(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]<dep[top[y]])swap(x,y); if(ask(1,1,n,dfn[top[x]],dfn[x])==tot) return 1; x=fa[top[x]]; } if(dep[x]<dep[y])swap(x,y); return (ask(1,1,n,dfn[y],dfn[x])==tot); } int main() { scanf("%d%d",&n,&q); for(int i=1,x,y;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1,0); top[1]=1; dfs2(1); for(int i=1,a,b,c,d;i<=q;i++) { scanf("%d%d%d%d",&a,&b,&c,&d); update(a,b,++tot); if(query(c,d))printf("Y\n"); else printf("N\n"); } return 0; }