题意:

给出一棵树,每个顶点上有个颜色\(c_i\)
有两种操作:

  • C a b c 将\(a \to b\)的路径所有顶点上的颜色变为c
  • Q a b 查询\(a \to b\)的路径上的颜色段数,连续相同颜色视为一段

分析:

首先树链剖分,下面考虑线段树部分:
我们维护一个区间的左端点的颜色和右断点的颜色以及该区间的颜色段数,在加一个颜色覆盖标记。
pushup的时候,如果左区间右端点颜色和右区间左端点颜色相同,那么这段颜色可以合并,合并区间的颜色段数为左右子区间颜色段数之和减1;
否则,答案为左右子区间颜色段数之和。

本题的特殊性在于区间合并的顺序性,我们是自底向上将两个顶点\(LCA\)的。因为在每条重链上,顶点在线段树上的编号是从上到下递增的。所以每个子查询得到的区间信息也是从上到下的。我们可以将所得区间左右翻转(具体就是交换区间左右端点颜色,颜色段数不会变)一下,再合并最终得到整个查询区间。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
 
const int maxn = 100000 + 10;
const int maxnode = maxn * 4;
 
struct Edge
{
    int v, nxt;
    Edge() {}
    Edge(int v, int nxt):v(v), nxt(nxt) {}
};
 
int n, m, a[maxn];
 
int ecnt, head[maxn];
Edge edges[maxn * 2];
 
void AddEdge(int u, int v) {
    edges[ecnt] = Edge(v, head[u]);
    head[u] = ecnt++;
    edges[ecnt] = Edge(u, head[v]);
    head[v] = ecnt++;
}
 
int fa[maxn], dep[maxn], sz[maxn], son[maxn];
int tot, top[maxn], id[maxn], pos[maxn];
 
void dfs(int u) {
    sz[u] = 1; son[u] = 0;
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs(v);
        sz[u] += sz[v];
        if(sz[v] > sz[son[u]]) son[u] = v;
    }
}
 
void dfs2(int u, int tp) {
    id[u] = ++tot;
    pos[tot] = u;
    top[u] = tp;
    if(!son[u]) return;
    dfs2(son[u], tp);
    for(int i = head[u]; ~i; i = edges[i].nxt) {
        int v = edges[i].v;
        if(v == fa[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}
 
struct Node
{
    int lcol, rcol, cntv;
 
    Node() {}
 
    Node(int l, int r, int c): lcol(l), rcol(r), cntv(c) {}
};
 
void reverse(Node& t) { swap(t.lcol, t.rcol); }
 
Node operator + (const Node& a, const Node& b) {
    if(!a.cntv) return b; if(!b.cntv) return a;
    return Node(a.lcol, b.rcol, a.cntv + b.cntv - 1 + (a.rcol != b.lcol));
}
 
int setv[maxnode];
Node t[maxnode];
 
void build(int o, int L, int R) {
    setv[o] = -1;
    if(L == R) {
        t[o].cntv = 1;
        t[o].lcol = t[o].rcol = a[pos[L]];
        return;
    }
    int M = (L + R) / 2;
    build(o<<1, L, M);
    build(o<<1|1, M+1, R);
    t[o] = t[o<<1] + t[o<<1|1];
}
 
void pushdown(int o) {
    if(setv[o] != -1) {
        setv[o<<1] = setv[o<<1|1] = setv[o];
        t[o<<1].cntv = t[o<<1|1].cntv = 1;
        t[o<<1].lcol = t[o<<1].rcol = t[o<<1|1].lcol = t[o<<1|1].rcol = setv[o];
        setv[o] = -1;
    }
}
 
void update(int o, int L, int R, int qL, int qR, int v) {
    if(qL <= L && R <= qR) {
        t[o].lcol = t[o].rcol = setv[o] = v;
        t[o].cntv = 1;
        return;
    }
    pushdown(o);
    int M = (L + R) / 2;
    if(qL <= M) update(o<<1, L, M, qL, qR, v);
    if(qR > M) update(o<<1|1, M+1, R, qL, qR, v);
    t[o] = t[o<<1] + t[o<<1|1];
}
 
void UPDATE(int u, int v, int val) {
    int t1 = top[u], t2 = top[v];
    while(t1 != t2) {
        if(dep[t1] < dep[t2]) { swap(u, v); swap(t1, t2); }
        update(1, 1, n, id[t1], id[u], val);
        u = fa[t1]; t1 = top[u];
    }
    if(dep[u] > dep[v]) swap(u, v);
    update(1, 1, n, id[u], id[v], val);
}
 
Node query(int o, int L, int R, int qL, int qR) {
    Node ans(0, 0, 0);
    if(qL <= L && R <= qR) return t[o];
    pushdown(o);
    int M = (L + R) / 2;
    if(qL <= M) ans = ans + query(o<<1, L, M, qL, qR);
    if(qR > M) ans = ans + query(o<<1|1, M+1, R, qL, qR);
    return ans;
}
 
int QUERY(int u, int v) {
    Node q1(0, 0, 0), q2(0, 0, 0), tmp;
    int t1 = top[u], t2 = top[v];
    while(t1 != t2) {
        if(dep[t1] > dep[t2]) {
            tmp = query(1, 1, n, id[t1], id[u]);
            reverse(tmp);
            q1 = q1 + tmp;
            u = fa[t1]; t1 = top[u];
        } else {
            tmp = query(1, 1, n, id[t2], id[v]);
            reverse(tmp);
            q2 = q2 + tmp;
            v = fa[t2]; t2 = top[v];
        }
    }
    if(dep[u] > dep[v]) {
        tmp = query(1, 1, n, id[v], id[u]);
        reverse(tmp);
        q1 = q1 + tmp;
    } else {
        tmp = query(1, 1, n, id[u], id[v]);
        reverse(tmp);
        q2 = q2 + tmp;
    }
 
    reverse(q2);
    q1 = q1 + q2;
    return q1.cntv;
}
 
int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) scanf("%d", a + i);
 
    ecnt = 0;
    memset(head, -1, sizeof(head));
    for(int i = 1; i < n; i++) {
        int u, v; scanf("%d%d", &u, &v);
        AddEdge(u, v);
    }
 
    dfs(1);
    tot = 0;
    dfs2(1, 1);
 
    build(1, 1, n);
 
    char cmd[5];
    int a, b, c;
    while(m--) {
        scanf("%s", cmd);
        scanf("%d%d", &a, &b);
        if(cmd[0] == 'C') {
            scanf("%d", &c);
            UPDATE(a, b, c);
        } else {
            printf("%d\n", QUERY(a, b));
        }
    }
 
    return 0;
}