题意:
给出一个有n个数的数列,并定义mex(l, r)表示数列中第l个元素到第r个元素中第一个没有出现的最小非负整数。
求出这个数列中所有mex的值。
思路:
可以看出对于一个数列,mex(r, r~l)是一个递增序列
mex(0, 0~n-1)是很好求的,只需要遍历找出第一个没有出现的最小非负整数就好了。这里有一个小技巧:
1 tmp = 0;
2 for (int i = 1; i <= n; ++i) {
3 mp[arr[i]] = 1;
4 while (mp.find(tmp) != mp.end()) tmp++;
5 mex[i] = tmp;
6 }
这样可以利用map中的红黑树很快找到第一个没有出现的最小非负整数。
然后在求mex(1~n-1, 0~n-1)的过程中,我们可以看出,每消除当前值arr[i],会影响到的是在下一个arr[i]出现前 往后的mex值中比arr[i]大的值,即如果当前这个值不存在了,那么在这个值下一次出现前,mex值比当前值大的mex值都应被替换成arr[i]。
所以我们可以再一次利用map的红黑树找到当前值下一次出现的位置,然后利用线段树成段更新往后的mex值和求出会影响到的mex值的个数。
1 for (int i = n; i >= 1; --i) {
2 if (mp.find(arr[i]) == mp.end()) next[i] = n+1;
3 else next[i] = mp[arr[i]];
4 mp[arr[i]] = i;
5 }
这里我们还需要利用线段树求出第一个比当前arr[i]大的mex值的位置,以便成段更新区间的mex值。
Tips:
※ 这里有一个小小优化的地方,就是当更新的时候,可以先查看mx[1]是否比当前值大,如果是,则表示往后的区间里有比当前值大的mex值,则需要线段树是需要更新的,否则不用更新。
※ 还有一个要注意的地方是:只有求出的左边界比右边界小的时候才能更新。
Code:

1 #include <stdio.h>
2 #include <cstring>
3 #include <map>
4 #include <algorithm>
5 using namespace std;
6
7 const int MAXN = 200010;
8 long long sum[MAXN<<2];
9 int mx[MAXN<<2], arr[MAXN], next[MAXN], mex[MAXN];
10 int lazy[MAXN<<2];
11
12 void Pushup(int rt)
13 {
14 sum[rt] = sum[rt<<1]+sum[rt<<1|1];
15 mx[rt] = max(mx[rt<<1], mx[rt<<1|1]);
16 }
17
18 void Pushdown(int rt, int x)
19 {
20 if (lazy[rt] != -1) {
21 lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt];
22 sum[rt<<1] = (x-x/2)*lazy[rt];
23 sum[rt<<1|1] = x/2*lazy[rt];
24 mx[rt<<1] = mx[rt<<1|1] = lazy[rt];
25 lazy[rt] = -1;
26 }
27 }
28
29 void Creat(int l, int r, int rt)
30 {
31 lazy[rt] = -1;
32 if (l == r) {
33 sum[rt] = mx[rt] = mex[l];
34 return;
35 }
36 int mid = (l+r)/2;
37 Creat(l, mid, rt<<1);
38 Creat(mid+1, r, rt<<1|1);
39 Pushup(rt);
40 }
41
42 void Modify(int l, int r, int x, int L, int R, int rt)
43 {
44 if (l <= L && r >= R) {
45 lazy[rt] = x;
46 sum[rt] = x*(R-L+1);
47 mx[rt] = x;
48 return;
49 }
50 Pushdown(rt, R-L+1);
51 int mid = (L+R)/2;
52 if (l <= mid) Modify(l, r, x, L, mid, rt<<1);
53 if (r > mid) Modify(l, r, x, mid+1, R, rt<<1|1);
54 Pushup(rt);
55 }
56
57 int Get(int rt, int l, int r, int x)
58 {
59 if(l == r) return l;
60 Pushdown(rt, r-l+1);
61 int mid = (l+r)/2;
62 if (mx[rt<<1] > x) return Get(rt<<1, l, mid, x);
63 else return Get(rt<<1|1, mid+1, r, x);
64 }
65
66 int main()
67 {
68 //freopen("in.txt", "r", stdin);
69 int n, tmp;
70 long long ans_sum;
71 map<int, int> mp;
72 while (~scanf("%d", &n)) {
73 if (n == 0) break;
74 ans_sum = tmp = 0;
75 mp.clear();
76 memset(sum, 0, sizeof(sum));
77 memset(next, 0, sizeof(next));
78
79 for (int i = 1; i <= n; ++i)
80 scanf("%d", &arr[i]);
81 for (int i = 1; i <= n; ++i) {
82 mp[arr[i]] = 1;
83 while (mp.find(tmp) != mp.end()) tmp++;
84 mex[i] = tmp;
85 }
86
87 Creat(1, n, 1);
88 mp.clear();
89 for (int i = n; i >= 1; --i) {
90 if (mp.find(arr[i]) == mp.end()) next[i] = n+1;
91 else next[i] = mp[arr[i]];
92 mp[arr[i]] = i;
93 }
94
95 for (int i = 1; i <= n; ++i) {
96 ans_sum += sum[1];
97 if (mx[1] > arr[i]) {
98 int l = Get(1, 1, n, arr[i]);
99 int r = next[i];
100 // printf("%d %d %d\n", l, r, sum[1]);
101 if (l < r) Modify(l, r-1, arr[i], 1, n, 1);
102 }
103
104 Modify(i, i, 0, 1, n, 1);
105 }
106 printf("%I64d\n", ans_sum);
107 }
108 return 0;
109 }
链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747
来源:https://www.cnblogs.com/Griselda/p/3433595.html
