// Solution: LCA + segment tree


#include <bits/stdc++.h>
#define ll long long
using namespace std;
namespace _{
    // Buffered reader over stdin. rEOF == 0 means end of input has been reached.
    char buf[100000], *p1 = buf, *p2 = buf; bool rEOF = 1;

    // Next input byte (as a non-negative value), or EOF once stdin is exhausted.
    // After EOF has been returned once, every later call returns EOF again
    // instead of reading stale bytes past the end of the buffer.
    inline int nc(){
        if (p1 == p2) {
            if (!rEOF) return EOF;
            p2 = (p1 = buf) + fread(buf, 1, sizeof buf, stdin);
            if (p1 == p2) { rEOF = 0; return EOF; }
        }
        return static_cast<unsigned char>(*p1++);
    }

    // Read an integer (an optional '-' before it makes it negative) into num.
    // Characters before the number are skipped; num is 0 if input ends first.
    template<class Int>
    inline void read(Int &num){
        int c = nc(), f = 1; num = 0;   // int, not char: EOF and the sign must survive unsigned-char platforms
        while ((c < '0' || c > '9') && c != EOF) { if (c == '-') f = -1; c = nc(); }
        while (c >= '0' && c <= '9') num = num * 10 + (c - '0'), c = nc();
        num *= f;
    }

    // Characters that may appear inside a token read by read_str.
    inline bool need(char c){
        return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z');
    }

    // Read the next token of need() characters into a (NUL-terminated).
    // Leading non-token characters are skipped; the delimiter that ends the
    // token is pushed back so the next read sees it. Yields "" at end of input.
    inline void read_str(char *a){
        int c = nc();
        while (c != EOF && !need(static_cast<char>(c))) c = nc();
        while (c != EOF && need(static_cast<char>(c))) *a++ = static_cast<char>(c), c = nc();
        *a = '\0';
        if (c != EOF) --p1;   // un-read the delimiter: the last nc() just stepped p1 past it
    }

    // Print x in decimal followed by '\n'. Negative values (including the most
    // negative value of the type) are printed as '-' followed by the magnitude.
    template<class Int>
    inline void println(Int x){
        int digits[40], len = 0;
        if (x < 0) putchar('-');
        do {
            int d = static_cast<int>(x % 10);
            digits[len++] = d < 0 ? -d : d;   // truncating division: remainder is negative for negative x
            x /= 10;
        } while (x);
        while (len) putchar(digits[--len] + '0');
        putchar('\n');
    }
}using namespace _;
const ll INF = 0x3f3f3f3f3f3f3f3fLL;   // large sentinel value (unused here)
const ll mod = 1000000007;             // 1e9 + 7 (unused here)
const int MAXN = 500000 + 5;           // capacity of the vertex-indexed arrays

int n;          // number of vertices
int q;          // number of queries
int k;          // number of vertices in the current query
int a[MAXN];    // the current query's vertices, stored in a[1..k]
 
// Adjacency list as singly linked edge lists: head[u] is the index of u's most
// recently added edge (-1 for none once head is memset to -1), and e[i].next
// is the index of the previously added edge out of the same vertex.
// Sized for two directed edges per undirected tree edge.
struct edge{
    int v,next;   // target vertex, index of the next edge out of the same source
}e[MAXN<<1];
int head[MAXN], cnt;
// Append the directed edge u -> v. Plain member assignment replaces the GNU
// compound literal `(edge){...}`, which is not standard C++.
inline void add(int u,int v){
    e[cnt].v = v;
    e[cnt].next = head[u];
    head[u] = cnt++;
}
 
int dep[MAXN], low[MAXN], dfn[MAXN], tot, fa[MAXN][20];
void dfs(int u,int f){
    dep[u] = dep[f] + 1, dfn[u] = ++tot, fa[u][0] = f;
    for(int i = 1;i<20;i++)fa[u][i] = fa[fa[u][i-1]][i-1];
    for(int i = head[u];~i;i = e[i].next){
        if(e[i].v == f)continue;
        dfs(e[i].v, u);
    }
    low[u] = tot;
}
// Lowest common ancestor of u and v using the binary-lifting table fa[][].
inline int lca(int u,int v){
    if(dep[v] > dep[u]) swap(u,v);        // make u the deeper vertex
    int diff = dep[u] - dep[v];
    for(int j = 0; diff > 0; ++j, diff >>= 1)   // lift u to v's depth bit by bit
        if(diff & 1) u = fa[u][j];
    if(u == v) return u;
    for(int j = 19; j >= 0; --j){
        if(fa[u][j] == fa[v][j]) continue;   // jumping this far would overshoot the LCA
        u = fa[u][j];
        v = fa[v][j];
    }
    return fa[u][0];
}
#define mid (((l) + (r)) >> 1)   // midpoint of the node range [l, r]; expects l and r in scope
// Segment-tree node; main() builds it over dfn positions 1..n.
struct Tree {
    int lazy;   // pending assignment to push to the children; -1 means none
    int sum;    // sum of the values currently assigned over this node's range
} T[MAXN << 2];
// Push node u's pending assignment (if any) down to its two children, whose
// ranges are [l, mid] and [mid+1, r]; afterwards u has no pending tag (-1).
inline void pushdown(int u,int l,int r){
    const int tag = T[u].lazy;
    if(tag == -1) return;   // nothing pending
    const int m = mid;
    const int ls = u<<1, rs = u<<1|1;
    T[ls].lazy = T[rs].lazy = tag;
    T[ls].sum = tag*(m-l+1);
    T[rs].sum = tag*(r-m);
    T[u].lazy = -1;
}
// Assign x to every position of [L, R] inside node u, which covers [l, r],
// and keep each visited node's sum equal to the sum over its range.
void update(int u,int l,int r,int L,int R,int x){
    if(l >= L && r <= R){   // node fully covered: set it and leave a tag for the children
        T[u].sum = x*(r-l+1);
        T[u].lazy = x;
        return;
    }
    pushdown(u,l,r);
    const int m = mid;
    const int ls = u<<1, rs = ls|1;
    if(L <= m) update(ls,l,m,L,R,x);
    if(m < R) update(rs,m+1,r,L,R,x);
    T[u].sum = T[ls].sum + T[rs].sum;
}
// Returns the ancestor of u whose depth is exactly d (u itself when d == dep[u]).
// Assumes 1 <= d <= dep[u]; relies on the root being its own ancestor in fa[][].
static inline int ancestor_at_depth(int u,int d){
    for(int j = 19;j>=0;j--)
        if(dep[fa[u][j]] >= d) u = fa[u][j];
    return u;
}

// Reads a tree rooted at vertex 1 and q queries. For a query a[1..k] it prints
// the number of vertices v for which no a[i] (i >= 2) is strictly closer to v
// than a[1] is (ties go to a[1]). Each a[i] marks, by range assignment over dfn
// order, the region of vertices strictly closer to it than to a[1]; the answer
// is n minus the number of marked positions.
int main() {
    memset(head,-1,sizeof head);
    read(n), read(q);
    for(int i = 1,u,v;i<n;i++){
        read(u), read(v), add(u,v), add(v,u);
    }
    dfs(1,1);   // vertex 1 is the root and acts as its own parent
    while(q--){
        read(k); for(int i = 1;i<=k;i++) read(a[i]);
        if(k == 1){
            println(n); continue;
        }
        // O(1) reset of the whole tree: a pending "assign 0" tag at the root.
        // The root's sum is cleared as well, so the answer is right even when
        // this query performs no update at all.
        T[1].lazy = 0, T[1].sum = 0;
        int rt = a[1];
        for(int i = 2;i<=k;i++){
            int x = a[i];
            if(x == rt) continue;   // x coincides with rt: every vertex is a tie, nothing to mark
            if(dfn[rt] <= dfn[x] && dfn[x] <= low[rt]){
                // Case 1: x lies in rt's subtree. The vertices strictly closer to x form
                // the subtree of x's ancestor at depth floor((dep[x]+dep[rt])/2)+1.
                int u = ancestor_at_depth(x, ((dep[x] + dep[rt])>>1) + 1);
                update(1,1,n,dfn[u],low[u],1);
            }else{
                int lc = lca(rt, x);
                // dis = distance from rt of the first path vertex strictly closer to x.
                int dis = ((dep[x] + dep[rt] - (dep[lc]<<1))>>1) + 1;
                if(dep[x] >= dep[rt]){
                    // Case 2: that vertex lies on the lc -> x part of the path; mark its subtree.
                    int u = ancestor_at_depth(x, dis - dep[rt] + (dep[lc]<<1));
                    update(1,1,n,dfn[u],low[u],1);
                }else{
                    // Case 3: the vertex at distance dis-1 from rt (on the rt -> lc part)
                    // is the top of rt's region; mark everything outside its subtree.
                    int u = ancestor_at_depth(rt, dep[rt] - dis + 1);
                    if(dfn[u] > 1) update(1,1,n,1,dfn[u]-1,1);
                    if(low[u] < n) update(1,1,n,low[u]+1,n,1);
                }
            }
        }
        println(n-T[1].sum);    // total vertices minus marked vertices
    }
    return 0;
}