P2486
很经典的题~
思路: 线段树染色+"熟练"剖分(某些出题人总是喜欢把序列上的题加个树链剖分搞到树上去)
先想一想序列上怎么做吧
线段树是个好东西
每个节点维护三个信息: ls: 左端点的颜色 rs: 右端点的颜色 cnt: [l, r] 中共有几个颜色段
合并?
fa.cnt = son1.cnt + son2.cnt - [son1.rs == son2.ls]
fa.ls = son1.ls , fa.rs = son2.rs
爹的左端点颜色就是左儿子的左端点颜色, 右端点亦然
如果左儿子与右儿子相接的颜色相同, 那么等于左儿子块数加右儿子块数-1(中间两个块会合成一个)
否则直接加就行啦
修改时要打标记 记录有没有被覆盖
回到树上问题时要特别注意的是询问
因为询问时有swap的操作, 将k记录x,y的顺序, 即相当于(x, y) 还是(y, x)
如果是(y, x), 最后还要反回来才能进行合并
在跳重链时, 总是将链接在它的左边, 最后将a左右儿子反一下再与b合并即可
#include<iostream> #include<cstdio> #include<cstdio> #define ll long long using namespace std; const int N = 105000*4; int fa[N], id[N], siz[N]; int num, dep[N], son[N]; int w[N], wt[N], Top[N]; int h[N], ne[N], to[N]; int tot; inline void add(int x,int y) { ne[++tot] = h[x], h[x] = tot; to[tot] = y; } void dfs1(int x,int f) { fa[x] = f; siz[x] = 1, dep[x] = dep[f] + 1; for (int i = h[x]; i ;i = ne[i]) { int y = to[i]; if (y == f) continue; dfs1(y, x); siz[x] += siz[y]; if (siz[y] > siz[son[x]]) son[x] = y; } } void dfs2(int x,int topf) { id[x] = ++num; wt[num] = w[x]; Top[x] = topf; if (!son[x]) return; dfs2(son[x], topf); for (int i = h[x]; i; i = ne[i]) { int y = to[i]; if (y == fa[x] || y == son[x]) continue; dfs2(y, y); } } int n, m; int L[N], R[N], cnt[N], ls[N], rs[N]; int tag[N]; #define p1 p << 1 #define p2 p << 1 | 1 struct node{ int cnt, ls, rs; }; void update(node &fa,node i,node j) { fa.cnt = i.cnt + j.cnt - (i.rs == j.ls); fa.ls = i.ls, fa.rs = j.rs; } void build(int l,int r,int p) { L[p] = l, R[p] = r; if (l == r) { cnt[p] = 1, ls[p] = rs[p] = wt[l]; return; } int mid = (l + r) >> 1; build(l, mid, p1); build(mid+1, r, p2); cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]); ls[p] = ls[p1], rs[p] = rs[p2]; } void spread(int p) { if (tag[p]) { cnt[p1] = cnt[p2] = 1; tag[p1] = tag[p2] = tag[p]; ls[p1] = ls[p2] = rs[p1] = rs[p2] = tag[p]; tag[p] = 0; } } void change(int l,int r,int p,int c) { if (L[p] >= l && R[p] <= r) { cnt[p] = 1, tag[p] = c; ls[p] = rs[p] = c; return; } spread(p); if (R[p1] >= l) change(l, r, p1, c); if (L[p2] <= r) change(l, r, p2, c); cnt[p] = cnt[p1] + cnt[p2] - (rs[p1] == ls[p2]); ls[p] = ls[p1], rs[p] = rs[p2]; } node ask(int l,int r,int p) { if (L[p] >= l && R[p] <= r) return (node){cnt[p], ls[p], rs[p]}; spread(p); node i; if (R[p1] < l) return ask(l, r, p2); if (L[p2] > r) return ask(l, r, p1); update(i, ask(l, r, p1), ask(l, r, p2)); return i; } void change_e(int x,int y,int c) { while (Top[x] != Top[y]) { if (dep[Top[x]] < dep[Top[y]]) swap(x, y); change(id[Top[x]], id[x], 1, c); x = fa[Top[x]]; } if (dep[x] < dep[y]) swap(x, y); change(id[y], id[x], 1, c); } int sum(int x,int y) { node ans, a, b; ans = a = b = (node){0,0,0}; int k = 1; while (Top[x] != Top[y]) { if (dep[Top[x]] < dep[Top[y]]) swap(x, y), swap(a, b), k ^= 1; if (a.cnt == 0) a = ask(id[Top[x]], id[x], 1); else update(a, ask(id[Top[x]], id[x], 1), a); x = fa[Top[x]]; } if (dep[x] < dep[y]) swap(x, y), swap(a, b); if (a.cnt == 0) a = ask(id[y], id[x], 1); else update(a, ask(id[y], id[x], 1), a); if (b.cnt == 0) return a.cnt; if (a.cnt == 0) return b.cnt; if (!k) swap(a, b); // 将a, b恢复正常顺序 swap(a.ls, a.rs); //将a左右儿子换位 update(ans, a, b); return ans.cnt; } char s[5]; ll a, b, c; template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (;!isdigit(c); c = getchar()) if (c == '-') f = -1; for (;isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0'; x *= f; } int main() { read(n), read(m); for (int i = 1;i <= n; i++) read(w[i]); for (int i = 1;i < n; i++) { read(a), read(b); add(a, b); add(b, a); } dfs1(1, 0); dfs2(1, 1); build(1, n, 1); while (m--) { scanf ("%s", s + 1); if (s[1] == 'C') { read(a), read(b), read(c); change_e(a, b, c); } else { read(a), read(b); printf ("%d\n", sum(a, b)); } } return 0; }