维护一个序列,支持如下操作
- 把[a, b]区间内的所有数全变成0
- 把[a, b]区间内的所有数全变成1
- 把[a,b]区间内所有的0变成1,所有的1变成0
- 询问[a, b]区间内总共有多少个1
- 询问[a, b]区间内最多有多少个连续的1
线段树
对于每个节点,维护对应区间
- sum:1的个数
- L0:连续0的最大长度
- L1:连续1的最长长度
- l0:包含区间左端点的连续0的最大长度
- l1:包含区间左端点的连续1的最大长度
- r0:包含区间右端点的连续0的最大长度
- r1:包含区间右端点的连续1的最大长度
- la:-1表示不变,0表示全为0,1表示全为1
- tn:0表示不变,1表示翻转
可推出关系
ls表示左儿子,rs表示右儿子,lenl表示左儿子区间长度,lenr表示右儿子区间长度
\(L0 = max(L0_{ls}, L0_{rs}, r0_{ls}+l0_{rs})\)
\(l0 = l0_{ls} + l0_{rs} * (l0_{ls} == lenl)\)
\(r0 = r0_{rs} + r0_{ls} * (r0_{rs} == lenr)\)
L1,l1,r1同理可得。
下传标记时先判断la,因为全都变为1个数可以覆盖翻转的结果
#include <cstdio> #include <algorithm> #define ci const int #define ls x << 1 #define rs x << 1 | 1 ci Maxn = 100000; typedef int Array[Maxn << 2]; int n, m, opt, a, b, tmp; struct Node { int sum, L1, l1, r1; }; struct Segement_Tree { Array sum, la, L0, L1, l0, l1, r0, r1, tn; void push_up(ci& x, ci& lenl, ci& lenr) { sum[x] = sum[ls] + sum[rs]; L0[x] = std::max(std::max(L0[ls], L0[rs]), r0[ls] + l0[rs]); L1[x] = std::max(std::max(L1[ls], L1[rs]), r1[ls] + l1[rs]); l0[x] = l0[ls] + l0[rs] * (l0[ls] == lenl); l1[x] = l1[ls] + l1[rs] * (l1[ls] == lenl); r0[x] = r0[rs] + r0[ls] * (r0[rs] == lenr); r1[x] = r1[rs] + r1[ls] * (r1[rs] == lenr); } void push_down(ci& x, ci& lenl, ci& lenr) { if (la[x] == 0) { sum[ls] = la[ls] = L1[ls] = l1[ls] = r1[ls] = tn[ls] = 0; sum[rs] = la[rs] = L1[rs] = l1[rs] = r1[rs] = tn[rs] = 0; L0[ls] = l0[ls] = r0[ls] = lenl; L0[rs] = l0[rs] = r0[rs] = lenr; la[x] = -1; } if (la[x] == 1) { sum[ls] = L1[ls] = l1[ls] = r1[ls] = lenl; sum[rs] = L1[rs] = l1[rs] = r1[rs] = lenr; L0[ls] = l0[ls] = r0[ls] = tn[ls] = 0; L0[rs] = l0[rs] = r0[rs] = tn[rs] = 0; la[ls] = la[rs] = 1; la[x] = -1; } if (tn[x]) { sum[ls] = lenl - sum[ls]; sum[rs] = lenr - sum[rs]; std::swap(L0[ls], L1[ls]); std::swap(L0[rs], L1[rs]); std::swap(l0[ls], l1[ls]); std::swap(l0[rs], l1[rs]); std::swap(r0[ls], r1[ls]); std::swap(r0[rs], r1[rs]); tn[ls] ^= 1; tn[rs] ^= 1; la[x] = -1; tn[x] = 0; } } void build(ci& x, ci& l, ci& r) { la[x] = -1; if (l == r) { scanf("%d", &tmp); sum[x] = L1[x] = l1[x] = r1[x] = tmp; L0[x] = l0[x] = r0[x] = tmp ^ 1; tn[x] = 0; return; } int mid = (l + r) >> 1; build(ls, l, mid); build(rs, mid + 1, r); push_up(x, mid - l + 1, r - mid); } void change(ci& x, ci& L, ci& R, ci& l, ci& r, ci& k) { if (L <= l && r <= R) { sum[x] = L1[x] = l1[x] = r1[x] = k ? r - l + 1 : 0; L0[x] = l0[x] = r0[x] = k ? 0 : r - l + 1; la[x] = k; tn[x] = 0; return; } if (R < l || r < L) return; int mid = (l + r) >> 1; push_down(x, mid - l + 1, r - mid); change(ls, L, R, l, mid, k); change(rs, L, R, mid + 1, r, k); push_up(x, mid - l + 1, r - mid); } void turn(ci& x, ci& L, ci& R, ci& l, ci& r) { if (L <= l && r <= R) { sum[x] = r - l + 1 - sum[x]; std::swap(L0[x], L1[x]); std::swap(l0[x], l1[x]); std::swap(r0[x], r1[x]); tn[x] ^= 1; return; } if (R < l || r < L) return; int mid = (l + r) >> 1; push_down(x, mid - l + 1, r - mid); turn(ls, L, R, l, mid); turn(rs, L, R, mid + 1, r); push_up(x, mid - l + 1, r - mid); } Node query(ci& x, ci& L, ci& R, ci& l, ci& r) { if (L <= l && r <= R) return (Node){sum[x], L1[x], l1[x], r1[x]}; if (R < l || r < L) return (Node){0, 0, 0, 0}; int mid = (l + r) >> 1; push_down(x, mid - l + 1, r - mid); Node s1 = query(ls, L, R, l, mid); Node s2 = query(rs, L, R, mid + 1, r); Node s; s.sum = s1.sum + s2.sum; s.L1 = std::max(std::max(s1.L1, s2.L1), s1.r1 + s2.l1); s.l1 = s1.l1 + s2.l1 * (s1.l1 == mid - l + 1); s.r1 = s2.r1 + s1.r1 * (s2.r1 == r - mid); return s; } }sgt; int main() { scanf("%d%d", &n, &m); sgt.build(1, 1, n); for (register int i = 1; i <= m; ++i) { scanf("%d%d%d", &opt, &a, &b); ++a; ++b; if (opt == 4) printf("%d\n", sgt.query(1, a, b, 1, n).L1); else if (opt == 3) printf("%d\n", sgt.query(1, a, b, 1, n).sum); else if (opt == 2) sgt.turn(1, a, b, 1, n); else sgt.change(1, a, b, 1, n, opt); } }