主席树笔记
- By BigYellowDog
- 前置芝士:线段树、前缀和、最好还有平衡树。
导入
- 主席树是什么?它可以干嘛?为什么要用它?
- 主席树这个名字跟其功能没有关系。名字来源据说是hjt大牛发明的,根据其网名命名的。
- 主席树真正的名字叫可持久化线段树。
- 假设现在有这么一个问题:现有一个序列,q次询问。每次问区间[l, r]里的最大值。
- 很简单啊,线段树 / st表就好。
- 那改一下,每次询问区间[l, r]里的次大值。
- 很简单啊,线段树就好,只不过多维护一个值而已。
- 那再改一下,每次询问区间[l, r]里的第k大值。
- ... ...
- emmm反正我是写不出来了。
那么主席树就是解决求区间第k大此类问题的。
当然,区间第k大还可以用整体二分。但主席树易理解,码量也不大。所以是一种性价比很高的算法。
原理
- 现在有一个序列3 5 4 1 2,请找出区间[1, 4]的第2大。
ok现在来模拟主席树算法过程。
建出一个空树(树0),如下图。
棕色数字表示节点的值。
- (1~5)这个节点表示1 <= 值 <= 5的个数
- (1~3)这个节点表示1 <= 值 <= 3的个数
... ...注意这个一颗权值线段树。点的左右端点表示权值的边界。

- 插入序列中第一个数3,图就变成了(树1)

- 插入序列中第二个数5,图就变成了(树2)

- 一直插完序列中第4个数1后,图就变成了(树4)

- 到这里打住,我们已经插入了前4个数,那么就可以查询区间[1, 4]的第k大了。
- 首先进入(1~5)节点,发现其左儿子的子树个数2 <= k(2),于是进入(1~3)节点
- 发现(1~3)节点的左儿子子树个数1 < k(2),于是进入(3~3)节点,k更新为2 - 1 = 1
- 走到头了,于是返回节点的左端点3。于是区间[1, 4]的第2小就是3
- 为什么返回左端点?因为这是一颗权值线段树。
- 上述在权值线段树里找第k大的过程就是平衡树找第k大的过程
- 那现在,请找出区间[2, 4]里的第2大值。
- 很容易啊,我们取出刚刚建出的树4和树1。拿树4 减 树1即可得到一个新树。
- 减,即拿每个对应节点相减。
- 在这个新树上找第2大即可。
- 这就是一个前缀和啊。
- 那利用主席树解决区间第k大的方法就出来了。
- 依次插入序列中的数,每插入一个树建立一个新树。
- 若查询区间[l, r],则拿出树r和树(l - 1)相减得到新树,在新树上找第k大。
实现
- 原理就是这样啦,但是如果每插入一个树就建立一颗新树,那岂不是空间爆炸?
- 是的,原理很简单,关键就是实现,这也是发明者的精妙之处。
- 首先可以发现一个性质,每插入一个数,相对于上一棵树来说,只会有从根节点到叶节点的一条链上的点的值发生了变化。其它都是不变的!
- 如图,这是上述序列插入3后的图,当插入4后只有蓝色路径上的节点的值会++

- 那我们每建立一颗新树就不用重新建了,而是值改变的点就重建,没改变的连到上一个树上去即可。那么上述的图可以变成

- 带 ' 的表示是插入4后的,不带 ' 的是插入3后的树。
- 这样建树的空间开销就大大减少了。
代码
- 搬来一道模板题
cin >> n >> m; for(int i = 1; i <= n; i++) a[i] = read(), b[++cnt] = a[i]; sort(b + 1, b + 1 + cnt); cnt = unique(b + 1, b + 1 + cnt) - b - 1; /** * 首先输入序列中的每一个数,因为权值很大,我们又要按照权值建树,那么就先离散化下 * a是原序列,b是离散化后的数组,cnt是离散化的不同权值个数 */
r[0] = build(1, cnt); //建一个空树, r[0]表示第0棵树的根节点 int build(int l, int r) //建树函数,各位都懂 { int p = ++dex, mid = l + r >> 1; if(l == r) return p; t[p].l = build(l, mid); t[p].r = build(mid + 1, r); return p; }
for(int i = 1; i <= n; i++) //这里就是每插入一个数建一个树的过程了 r[i] = upd(r[i - 1], 1, cnt, find(a[i])); int upd(int las, int l, int r, int val) { int p = ++dex, mid = l + r >> 1; t[p].l = t[las].l, t[p].r = t[las].r; //首先都给它连上上一棵树的节点 t[p].sum = t[las].sum + 1; if(l == r) return p; if(val <= mid) t[p].l = upd(t[las].l, l, mid, val); //说明左子树发生了改变 else t[p].r = upd(t[las].r, mid + 1, r, val); //说明右子树发生了改变 return p; }
for(int i = 1; i <= m; i++) { int ll = read(), rr = read(), rank = read(); printf("%d\n", b[ask(r[ll - 1], r[rr], 1, cnt, rank)]); } int ask(int u, int v, int l, int r, int rank) //取出了树u和树v { if(l == r) return l; int size = t[t[v].l].sum - t[t[u].l].sum, mid = l + r >> 1; //得到新树的左儿子子树个数 if(rank <= size) return ask(t[u].l, t[v].l, l, mid, rank); else return ask(t[u].r, t[v].r, mid + 1, r, rank - size); }
- 最后是完整代码:链接