题目
Description
给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。
Sample Input
6 5
2 2 1 2 1 1
1 2
1 3
2 4
2 5
2 6
Q 3 5
C 2 1 1
Q 3 5
C 5 1 2
Q 3 5
Sample Output
3
1
2
HINT
数N<=10^5,操作数M<=10^5,所有的颜色C为整数且在[0, 10^9]之间。
分析
我实在是懒得分析了emm...这题目调了半天!!!
还是写写吧
正解: 树剖+线段树
因为要判断颜色的段数,所以想到在线段树里要维护个左右端点的颜色,并且在查询的时候要判断颜色,为了防止算重复。
线段树的操作就不写了(垃圾博主就是忘了写下传tag到儿子的tag, 他脑子让驴给踢了,不用管他)(注意,线段树的query可能要变一下),考虑查询操作:当前链与上一次的链在相交的边缘可能颜色相同,如果颜色相同答案需要减一。所以统计答案的时候要记录下上一次剖到的链所在线段树区间的左端点,每次与当前链所在线段树的右端点比较(想想线段树的查询和in[]数组)
又由于有x和y两个位置在向上走,那么要记录ans1,ans2两个变量来存“上一次的左端点颜色”, 每次交换x,y时记得交换ans1,ans2
注意最后在同一条重链上时情况不一样,需要自己手胡一下
#include<cstdio> #include<iostream> #include<algorithm> using namespace std; const int MAXN = 100000+99; const int MAXM = MAXN<<1; int n, m; int pos[MAXN]; struct node{ int deep, size, fa, son, tp, in, color; }a[MAXN]; int _clock; struct seg{ int y, next; }e[MAXM]; int head[MAXN], cnt; void add_edge(int x, int y) { e[++cnt].y = y; e[cnt].next = head[x]; head[x] = cnt; } void dfs1(int x, int fa) { a[x].deep = a[fa].deep + 1; a[x].fa = fa; a[x].size = 1; for(int i = head[x]; i; i = e[i].next) if(e[i].y != fa) { dfs1(e[i].y, x); a[x].size += a[e[i].y].size; a[x].son = a[a[x].son].size > a[e[i].y].size ? a[x].son : e[i].y; } } struct tree{ int mx, L, R, lazyset; }tr[MAXN<<2]; void dfs2(int x, int tp) { a[x].tp = tp; a[x].in = ++_clock; pos[_clock] = a[x].color; if(a[x].son) dfs2(a[x].son, tp); for(int i = head[x]; i; i = e[i].next) if(e[i].y != a[x].fa && e[i].y != a[x].son) { dfs2(e[i].y, e[i].y); } } void pushup(int o) { tr[o].mx = tr[o<<1].mx + tr[o<<1|1].mx; if(tr[o<<1].R == tr[o<<1|1].L) tr[o].mx--; tr[o].L = tr[o<<1].L; tr[o].R = tr[o<<1|1].R;//别写漏了 } void build(int o, int l, int r) { tr[o].lazyset = 0; if(l == r) { tr[o].L = tr[o].R = pos[l]; tr[o].mx = 1; return ; } int mid = (l+r)>>1; build(o<<1, l, mid); build(o<<1|1, mid+1, r); pushup(o); } void pushdown(int o) { if(tr[o].lazyset == 0) return ; tr[o<<1].L = tr[o<<1].R = tr[o].lazyset ; tr[o<<1|1].L = tr[o<<1|1].R = tr[o].lazyset ; tr[o<<1].mx = tr[o<<1|1].mx = 1; tr[o<<1].lazyset = tr[o<<1|1].lazyset = tr[o].lazyset ; tr[o].lazyset = 0; } void optset(int o, int l, int r, int ql, int qr, int k) { if(ql <= l && r <= qr) { tr[o].L = tr[o].R = k; tr[o].mx = 1; tr[o].lazyset = k; return ; } pushdown(o); int mid = (l+r)>>1; if(ql <= mid) optset(o<<1, l, mid, ql, qr, k); if(mid < qr) optset(o<<1|1, mid+1, r, ql, qr, k); pushup(o); } int Lcolor, Rcolor; int query(int o, int l, int r, int ql, int qr) { if(l == ql) Lcolor = tr[o].L; if(r == qr) Rcolor = tr[o].R; if(ql <= l && r <= qr) { return tr[o].mx ; } int mid = (l+r)>>1, ans = 0; pushdown(o); //需要求出Lcolor和Rcolor,所以要像下面这样写...? if(qr <= mid) { return query(o<<1, l, mid, ql, qr); } else if(mid < ql) { return query(o<<1|1, mid+1, r, ql, qr); } else { ans += query(o<<1, l, mid, ql, qr); ans += query(o<<1|1, mid+1, r,ql, qr); if(tr[o<<1].R == tr[o<<1|1].L) ans--; return ans; } //线段树查询的时候也要考虑 } void ttt_update(int x, int y, int k) { while(a[x].tp != a[y].tp) { if(a[a[x].tp].deep < a[a[y].tp].deep) swap(x,y); optset(1, 1, n, a[a[x].tp].in, a[x].in, k); x = a[a[x].tp].fa; } if(a[x].deep > a[y].deep) swap(x,y); optset(1, 1, n, a[x].in, a[y].in, k); } int ttt_query(int x, int y) { int ans = 0, ans1 = -1, ans2 = -1; //ans1,ans2分别记录x,y 上一条被剖的链所在线段树的区间的左端点 //每次与当前链所在线段树的右端点比较(想想线段树的查询和in[]数组) while(a[x].tp != a[y].tp) { if(a[a[x].tp].deep < a[a[y].tp].deep) { swap(x, y); swap(ans1, ans2); } ans += query(1, 1, n, a[a[x].tp].in, a[x].in); if(ans1 == Rcolor) { ans--; } ans1 = Lcolor; x = a[a[x].tp].fa; } if(a[x].deep > a[y].deep) { swap(x, y); swap(ans1, ans2); } // if(a[x].deep < a[y].deep) // swap(x,y),swap(ans1,ans2); ans += query(1, 1, n, a[x].in, a[y].in); if(ans1 == Lcolor) ans--; if(ans2 == Rcolor) ans--; return ans; } int main() { scanf("%d%d",&n,&m); for(int i = 1; i <= n; i++) scanf("%d",&a[i].color); int x,y; for(int i = 1; i < n; i++) { scanf("%d%d",&x, &y); add_edge(x,y); add_edge(y,x); } dfs1(1, 0); dfs2(1, 1); build(1, 1, n); char cmd; int k; for(int i = 1; i <= m; i++) { cin>>cmd; if(cmd == 'C') { scanf("%d%d%d",&x,&y,&k); ttt_update(x, y, k); } else { scanf("%d%d",&x,&y); printf("%d\n",ttt_query(x,y)); } } }