这是一个有点邪教的两 \(\log\) 做法,并且常数大,过题需要卡。

看到操作和题目名称,不难想到树剖。

不妨设以 \(1\) 为根,一条边的贡献由其深度较大的那个端点计算。难点在于如何安排 dfn 序,使得一条链和与一条链上的点直接相连的点的 dfn 是连续的。重链剖分后对于每条重链处理 dfn 序。考虑把处理 dfn 序的 dfs 改为 bfs,用队列记录待处理重链的顶端结点。首先让 \(1\) 入队,对于队首 \(u\),首先为 \(u\) 所在的重链安排连续的 dfn 序,然后为这条重链上所有点的所有轻儿子安排连续的 dfn 序,再将这条重链上所有点的所有轻儿子入队。发现重链顶端结点的 dfn 序和其他结点并不连续,不过不影响复杂度。

于是修改和查询的时候拆成两条路径就可以跳重链做了,需要注意的是修改第二条路径会对影响第一条的修改(第一条路径中 LCA 的儿子代表的边会被修改为轻边),暴力改回来即可。

时间复杂度 \(\mathcal{O}(n \log^2 n)\),实现起来较为麻烦,且需要卡常(例如我的代码中线段树区间修改和单点修改拆开了)。

#include <bits/stdc++.h>

using namespace std;

#define il inline
#define re register
#define rep(i, s, e) for (re int i = s; i <= e; ++i)
#define drep(i, s, e) for (re int i = s; i >= e; --i)
#define file(a) freopen(#a".in", "r", stdin), freopen(#a".out", "w", stdout)

const int N = 1000000 + 10;

il int read() {
int x = 0; bool f = true; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = false; c = getchar();}
while (isdigit(c)) x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? x : -x;
}

int n, m;
vector <int> e[N];

int dat[N << 2], tag[N << 2];

#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid (l + r >> 1)

void reset(int u, int l, int r) {
dat[u] = 0, tag[u] = -1;
if (l == r) return;
reset(ls, l, mid), reset(rs, mid + 1, r);
}

il void pushdown(int u, int l, int r) {
if (tag[u] == -1) return;
dat[ls] = tag[u] * (mid - l + 1), dat[rs] = tag[u] * (r - mid);
tag[ls] = tag[u], tag[rs] = tag[u], tag[u] = -1;
}

void modify(int ml, int mr, int k, int u, int l, int r) {
if (ml <= l && r <= mr) { dat[u] = k * (r - l + 1), tag[u] = k; return; }
pushdown(u, l, r);
if (ml <= mid) modify(ml, mr, k, ls, l, mid);
if (mr > mid) modify(ml, mr, k, rs, mid + 1, r);
dat[u] = dat[ls] + dat[rs];
}

void single_modify(int p, int k, int u, int l, int r) {
if (l == r) { dat[u] = tag[u] = k; return; }
pushdown(u, l, r);
if (p <= mid) single_modify(p, k, ls, l, mid);
else single_modify(p, k, rs, mid + 1, r);
dat[u] = dat[ls] + dat[rs];
}

int query(int ql, int qr, int u, int l, int r) {
if (ql <= l && r <= qr) return dat[u];
pushdown(u, l, r);
int res = 0;
if (ql <= mid) res += query(ql, qr, ls, l, mid);
if (qr > mid) res += query(ql, qr, rs, mid + 1, r);
return res;
}

int sz[N], mson[N], dep[N], fr[N];
void dfs1(int u, int fa) {
sz[u] = 1, mson[u] = 0, fr[u] = fa;
for (int v : e[u]) {
if (v == fa) continue;
dep[v] = dep[u] + 1, dfs1(v, u), sz[u] += sz[v];
if (sz[v] > sz[mson[u]]) mson[u] = v;
}
}

int top[N], dfn[N], dcnt, lef[N], rig[N];
void dfs2(int u, int fa, int tp) {
top[u] = tp;
if (mson[u]) dfs2(mson[u], u, tp);
for (int v : e[u]) {
if (v != fa && v != mson[u]) dfs2(v, u, v);
}
}

void make_dfn() {
queue <int> q;
q.push(1);
while (q.size()) {
int u = q.front(); q.pop();
for (int v = mson[u]; v; v = mson[v]) dfn[v] = ++ dcnt;
for (int v = u; v; v = mson[v]) {
lef[v] = dcnt + 1;
for (int w : e[v]) if (w != fr[v] && w != mson[v]) dfn[w] = ++ dcnt, q.push(w);
rig[v] = dcnt;
}
}
}

int LCA(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] < dep[top[v]]) swap(u, v);
u = fr[top[u]];
}
return (dep[u] < dep[v]) ? u : v;
}

void doit(int u, int lca) {
int v = 0;
while (dep[lca] < dep[top[u]]) {
if (mson[u]) single_modify(dfn[mson[u]], 0, 1, 1, n);
if (lef[top[u]] <= rig[u]) modify(lef[top[u]], rig[u], 0, 1, 1, n);
if (u != top[u]) modify(dfn[mson[top[u]]], dfn[u], 1, 1, 1, n);
if (v) single_modify(dfn[v], 1, 1, 1, n);
u = top[u], v = u, single_modify(dfn[u], 1, 1, 1, n), u = fr[u];
}
if (dfn[lca]) single_modify(dfn[lca], 0, 1, 1, n);
if (mson[u]) single_modify(dfn[mson[u]], 0, 1, 1, n);
if (lef[lca] <= rig[u]) modify(lef[lca], rig[u], 0, 1, 1, n);
if (u != lca) modify(dfn[mson[lca]], dfn[u], 1, 1, 1, n);
if (v) single_modify(dfn[v], 1, 1, 1, n);
}

void change(int u, int v) {
int lca = LCA(u, v);
if (u == lca) doit(v, u);
else if (v == lca) doit(u, v);
else {
doit(u, lca), doit(v, lca);
int w;
for (w = u; dep[top[w]] > dep[lca]; w = fr[top[w]]) {
w = top[w];
if (fr[w] == lca) single_modify(dfn[w], 1, 1, 1, n);
}
if (w != lca) single_modify(dfn[mson[lca]], 1, 1, 1, n);
}
}

int askit(int u, int lca) {
int res = 0;
while (dep[lca] < dep[top[u]]) {
if (u != top[u]) res += query(dfn[mson[top[u]]], dfn[u], 1, 1, n);
u = top[u], res += query(dfn[u], dfn[u], 1, 1, n), u = fr[u];
}
if (u != lca) res += query(dfn[mson[lca]], dfn[u], 1, 1, n);
return res;
}

int ask(int u, int v) {
int lca = LCA(u, v);
return askit(u, lca) + askit(v, lca);
}

int main() {
int tc = read();
while (tc --) {
n = read(), m = read(), dcnt = 0;
reset(1, 1, n);
rep(i, 1, n) e[i].clear();
rep(i, 1, n - 1) {
int u = read(), v = read();
e[u].push_back(v), e[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, 1);
make_dfn();
while (m --) {
int op = read(), u = read(), v = read();
if (op == 1) change(u, v);
else printf("%d\n", ask(u, v));
}
}
return 0;
}