洛谷 P2572 [SCOI2010]序列操作

匿名 (未验证) 提交于 2019-12-02 23:49:02

维护一个序列,支持如下操作

  • 把[a, b]区间内的所有数全变成0
  • 把[a, b]区间内的所有数全变成1
  • 把[a,b]区间内所有的0变成1,所有的1变成0
  • 询问[a, b]区间内总共有多少个1
  • 询问[a, b]区间内最多有多少个连续的1

线段树
对于每个节点,维护对应区间

  • sum:1的个数
  • L0:连续0的最大长度
  • L1:连续1的最长长度
  • l0:包含区间左端点的连续0的最大长度
  • l1:包含区间左端点的连续1的最大长度
  • r0:包含区间右端点的连续0的最大长度
  • r1:包含区间右端点的连续1的最大长度
  • la:-1表示不变,0表示全为0,1表示全为1
  • tn:0表示不变,1表示翻转

可推出关系
ls表示左儿子,rs表示右儿子,lenl表示左儿子区间长度,lenr表示右儿子区间长度
\(L0 = max(L0_{ls}, L0_{rs}, r0_{ls}+l0_{rs})\)
\(l0 = l0_{ls} + l0_{rs} * (l0_{ls} == lenl)\)
\(r0 = r0_{rs} + r0_{ls} * (r0_{rs} == lenr)\)
L1,l1,r1同理可得。
下传标记时先判断la,因为全都变为1个数可以覆盖翻转的结果

#include <cstdio> #include <algorithm> #define ci const int #define ls x << 1 #define rs x << 1 | 1 ci Maxn = 100000; typedef int Array[Maxn << 2]; int n, m, opt, a, b, tmp; struct Node { int sum, L1, l1, r1; }; struct Segement_Tree {     Array sum, la, L0, L1, l0, l1, r0, r1, tn;     void push_up(ci& x, ci& lenl, ci& lenr)     {         sum[x] = sum[ls] + sum[rs];         L0[x] = std::max(std::max(L0[ls], L0[rs]), r0[ls] + l0[rs]);         L1[x] = std::max(std::max(L1[ls], L1[rs]), r1[ls] + l1[rs]);         l0[x] = l0[ls] + l0[rs] * (l0[ls] == lenl);         l1[x] = l1[ls] + l1[rs] * (l1[ls] == lenl);         r0[x] = r0[rs] + r0[ls] * (r0[rs] == lenr);         r1[x] = r1[rs] + r1[ls] * (r1[rs] == lenr);     }     void push_down(ci& x, ci& lenl, ci& lenr)     {         if (la[x] == 0)         {             sum[ls] = la[ls] = L1[ls] = l1[ls] = r1[ls] = tn[ls] = 0;             sum[rs] = la[rs] = L1[rs] = l1[rs] = r1[rs] = tn[rs] = 0;             L0[ls] = l0[ls] = r0[ls] = lenl;             L0[rs] = l0[rs] = r0[rs] = lenr;             la[x] = -1;         }         if (la[x] == 1)         {             sum[ls] = L1[ls] = l1[ls] = r1[ls] = lenl;             sum[rs] = L1[rs] = l1[rs] = r1[rs] = lenr;             L0[ls] = l0[ls] = r0[ls] = tn[ls] = 0;             L0[rs] = l0[rs] = r0[rs] = tn[rs] = 0;             la[ls] = la[rs] = 1; la[x] = -1;         }         if (tn[x])         {             sum[ls] = lenl - sum[ls]; sum[rs] = lenr - sum[rs];             std::swap(L0[ls], L1[ls]); std::swap(L0[rs], L1[rs]);             std::swap(l0[ls], l1[ls]); std::swap(l0[rs], l1[rs]);             std::swap(r0[ls], r1[ls]); std::swap(r0[rs], r1[rs]);             tn[ls] ^= 1; tn[rs] ^= 1;             la[x] = -1; tn[x] = 0;         }     }     void build(ci& x, ci& l, ci& r)     {         la[x] = -1;         if (l == r)         {             scanf("%d", &tmp);             sum[x] = L1[x] = l1[x] = r1[x] = tmp;             L0[x] = l0[x] = r0[x] = tmp ^ 1;             tn[x] = 0;             return;         }         int mid = (l + r) >> 1;         build(ls, l, mid); build(rs, mid + 1, r);         push_up(x, mid - l + 1, r - mid);     }     void change(ci& x, ci& L, ci& R, ci& l, ci& r, ci& k)     {         if (L <= l && r <= R)         {             sum[x] = L1[x] = l1[x] = r1[x] = k ? r - l + 1 : 0;             L0[x] = l0[x] = r0[x] = k ? 0 : r - l + 1;             la[x] = k; tn[x] = 0;             return;         }         if (R < l || r < L) return;         int mid = (l + r) >> 1;         push_down(x, mid - l + 1, r - mid);         change(ls, L, R, l, mid, k); change(rs, L, R, mid + 1, r, k);         push_up(x, mid - l + 1, r - mid);     }     void turn(ci& x, ci& L, ci& R, ci& l, ci& r)     {         if (L <= l && r <= R)         {             sum[x] = r - l + 1 - sum[x];             std::swap(L0[x], L1[x]);             std::swap(l0[x], l1[x]);             std::swap(r0[x], r1[x]);             tn[x] ^= 1;             return;         }         if (R < l || r < L) return;         int mid = (l + r) >> 1;         push_down(x, mid - l + 1, r - mid);         turn(ls, L, R, l, mid); turn(rs, L, R, mid + 1, r);         push_up(x, mid - l + 1, r - mid);     }     Node query(ci& x, ci& L, ci& R, ci& l, ci& r)     {         if (L <= l && r <= R) return (Node){sum[x], L1[x], l1[x], r1[x]};         if (R < l || r < L) return (Node){0, 0, 0, 0};         int mid = (l + r) >> 1;         push_down(x, mid - l + 1, r - mid);         Node s1 = query(ls, L, R, l, mid);         Node s2 = query(rs, L, R, mid + 1, r);         Node s;         s.sum = s1.sum + s2.sum;         s.L1 = std::max(std::max(s1.L1, s2.L1), s1.r1 + s2.l1);         s.l1 = s1.l1 + s2.l1 * (s1.l1 == mid - l + 1);         s.r1 = s2.r1 + s1.r1 * (s2.r1 == r - mid);         return s;     } }sgt; int main() {     scanf("%d%d", &n, &m);     sgt.build(1, 1, n);     for (register int i = 1; i <= m; ++i)     {         scanf("%d%d%d", &opt, &a, &b); ++a; ++b;         if (opt == 4) printf("%d\n", sgt.query(1, a, b, 1, n).L1);         else if (opt == 3) printf("%d\n", sgt.query(1, a, b, 1, n).sum);         else if (opt == 2) sgt.turn(1, a, b, 1, n);         else sgt.change(1, a, b, 1, n, opt);     } }
标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!