权值线段树学习笔记
参考博文:
https://www.cnblogs.com/zmyzmy/p/9529234.html
权值线段树:
- 权值线段树维护数的个数,数组下标代表整个值域,如果太大可以采用离散化。
定义:
struct SegmentTree { int l, r; int s; //节点p的s表示这一段值域中数的个数总和 #define l(x) tree[x].l #define r(x) tree[x].r #define s(x) tree[x].s #define lson (p<<1) #define rson (p<<1|1) }tree[maxn<<2];
建树:
void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) { //初始化 return; } int mid = (l + r) >> 1; build(lson, l, mid); build(rson, mid+1, r); //pushup() }
单点更新:
void change(int p, int x) { if(l(p) == r(p)) { //更新数据 return; } int mid = (l(p) + r(p)) >> 1; if(x <= mid) change(lson, x); else change(rson, x); //pushup() }
询问整体第\(k\)小:
//询问整个区间第k小 //s(p)代表l(p)到r(p)值域中树的个数总和 int query(int p, int k) { if(l(p) == r(p)) return l(p); //由于数组下标维护的是值域,直接返回下标 if(k <= s(lson)) return query(lson, k); //在左子树中 else return query(rson, k - s(lson)); //在右子树中,感觉和平衡树好像 }
例题1:洛谷 https://www.luogu.org/problem/P1801
思路:
- 依题意模拟
代码:
#include<bits/stdc++.h> #include<cstring> using namespace std; typedef long long ll; const int maxn = 2e5 + 10; int a[maxn], num[maxn], u[maxn]; int n, m, len; struct SegmentTree { int l, r; int s; //节点p的s表示这一段值域中数的个数总和 #define l(x) tree[x].l #define r(x) tree[x].r #define s(x) tree[x].s #define lson (p<<1) #define rson (p<<1|1) }tree[maxn<<2]; void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) return; int mid = (l + r) >> 1; build(lson, l, mid); build(rson, mid+1, r); } void change(int p, int x) { if(l(p) == r(p)) { s(p) += 1; return; } int mid = (l(p) + r(p)) >> 1; if(x <= mid) change(lson, x); else change(rson, x); s(p) = s(lson) + s(rson); } //询问整个区间第k大 //s(p)代表l(p)到r(p)值域中树的个数总和 int query(int p, int k) { if(l(p) == r(p)) return l(p); //由于数组下标维护的是值域,直接返回下标 if(k <= s(lson)) return query(lson, k); //在左子树中 else return query(rson, k - s(lson)); //在右子树中,感觉和平衡树好像 } int main() { scanf("%d%d", &m, &n); for(int i = 1; i <= m; i++) { scanf("%d", &a[i]); num[i] = a[i]; } for(int i = 1; i <= n; i++) scanf("%d", &u[i]); sort(num + 1, num + 1 + m); len = unique(num + 1, num + 1 + m) - num - 1; build(1, 1, len); int cnt = 0, k = 0; while(n != cnt) { cnt++; for(int i = u[cnt-1] + 1; i <= u[cnt]; i++) { int y = lower_bound(num+1, num+1+len, a[i]) - num; //y是a(i)在num里的下标 change(1, y); } cout << num[query(1, ++k)] << endl; } return 0; }
例题2:洛谷1908:逆序对(权值线段树写法)
题意描述:
- 求逆序对数目。
思路:
- 见注释
代码:
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn = 5e5 + 10; ll ans; int n, a[maxn], num[maxn], len; struct SegmentTree { int l, r; ll s; #define l(x) tree[x].l #define r(x) tree[x].r #define lson (p<<1) #define rson (p<<1|1) #define s(x) tree[x].s }tree[maxn<<2]; inline void pushup(int p){ s(p) = s(lson) + s(rson); } inline void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) return; int mid = (l + r) >> 1; build(lson, l, mid); build(rson, mid + 1, r); } inline void change(int p, int x) { if(l(p) == r(p)) { s(p)++; return; } int mid = (l(p) + r(p)) >> 1; if(x <= mid) change(lson, x); else change(rson, x); pushup(p); } ll query(int p, int x) { if(l(p) == r(p)) return s(p); int mid = (l(p) + r(p)) >> 1; if(x <= mid) return query(lson, x) + s(rson); else return query(rson, x); } int main() { scanf("%d", &n); for(int i = 1; i <= n; i++) { scanf("%d", &a[i]); num[i] = a[i]; } build(1, 1, n); sort(num + 1, num + 1 + n); len = unique(num + 1, num + 1 + n) - num - 1; for(int i = 1; i <= n; i++) { int p = lower_bound(num + 1, num + 1 + n, a[i]) - num; a[i] = p; } for(int i = 1; i <= n; i++) //枚举每个a(i)作为右端点 { //看树中有多少比他大的数字 ans += query(1, a[i] + 1); //寻找比当前数大的数字的个数 //+1是因为要过滤掉等于a(i)的 change(1, a[i]); //在权值线段树中加上该节点 } cout << ans << endl; return 0; }
例题3:hdu_4217
题意描述:
- 给定一个\(1\)到\(n\)的序列。每次操作查询序列第\(k\)小的数字加入答案并拿走这个数字,问最后拿走数字的总和是多少
- \(n\leq 3e5\)
#include<bits/stdc++.h> using namespace std; const int maxn = 3e5 + 10; int T, n, m, k, cas; struct SegmentTree { int l, r; int s; #define l(x) tree[x].l #define r(x) tree[x].r #define lson (p<<1) #define rson (p<<1|1) #define s(x) tree[x].s }tree[maxn<<2]; void pushup(int p){ s(p) = s(lson) + s(rson); } void build(int p, int l, int r) { l(p) = l, r(p) = r; if(l == r) {s(p) = 1; return;} int mid = (l + r) >> 1; build(lson, l, mid); build(rson, mid+1, r); pushup(p); } int query(int p, int k) { if(l(p) == r(p)) return l(p); if(k <= s(lson)) return query(lson, k); else return query(rson, k - s(lson)); } void change(int p, int x, int val) { if(l(p) == r(p)) { s(p) = val; return; } int mid = (l(p) + r(p)) >> 1; if(x <= mid) change(lson, x, val); else change(rson, x, val); pushup(p); } int main() { scanf("%d", &T); while(T--) { scanf("%d%d", &n, &m); build(1, 1, n); long long ans = 0; while(m--) { scanf("%d", &k); int num = query(1, k); ans += num; change(1, num, 0); } printf("Case %d: %lld\n", ++cas, ans); } return 0; }