线段树

邮差的信 提交于 2020-01-20 22:15:02

1. 概念

1.1 定义

线段树使用一个完全二叉树来存储每个区间(segment) 的数据。线段树所使用的二叉树是用一个数组保存的。

完全二叉树:除了最后一层之外的其他每一层都被完全填充,并且所有结点都保持向左对。参考 https://www.zhihu.com/question/19809666/answer/13029983

对于长度为 n 的线段数组有如下性质:

  • 树的高度 log(n)

  • 节点的个数 2n -1「 (n) 叶子节点 +( n -1 ) 内部节点」 <=> 2Log2n+12^{⌈ Log_2^n ⌉ + 1} -1

    根据完全二叉树的性质,树高 h = ⌈Log2nLog_2^n ⌉ +1 ,节点个数 n = 2h2^h - 1

    由此节点个数 2Log2n+12^{⌈ Log_2^n ⌉ + 1} -1 => 22Log2n2*2^{⌈ Log_2^n ⌉} -1

    参考:https://stackoverflow.com/questions/28470692/how-is-the-memory-of-the-array-of-segment-tree-2-2-ceillogn-1

1.2 使用场景

  • 更新
    • 指定区间的某个值
    • 更新某个区间的所有值
  • 查询指定区间的统计信息
    • 求最值
    • 区间和

这么写使用场景对于一开始研究的我来说,有一种「雾里看花」的感觉。为了更好的学习这种数据结构而不是纸上谈兵,我想了一个使用的例子。

过年各家 APP 都会绞尽脑汁的出来年度账单。就以年度账单为例,套用下上述场景的使用方便理解。假设最近你想分析一下你一年十二个月的消费信息。

  • 作为一个细心的分析者,你不仅需要知道总的消费金额,你还想知道每个季度消费金额,或者指定范围月份的消费金额 ==> 区间和

  • 你还想要知道一年之哪个月份消费最高或者哪个月份消费最低 ==> 求最值

  • 分析银行的消费记录,你发现某个月份的消费记录记录少了,要重新更正下 ==> 更新指定区间的某个值

1.2.1 时间复杂度

  • 构建线段树的时间复杂度 O(n)
  • 最值查询,更新指定值,区间求和的时间复杂度 O(log(n))

根据线段树的使用场景不同,在实际构建的过程中也会使用不同的构建策略,下面会分别讲述。

2. 线段树——区间求和及更新

「本着一图胜千言」的策略,请看图 + 代码理解。当然这里的例子举得是非线段化的,不过可以脑补成 「1,1」,「3,3」 这种,如果按照区间格式化,会比较难理解(ps 主要是画多个节点更累

在这里插入图片描述

go 线段树实现:(ps 原谅我这么 low 的命名吧

package segmenttree

import (
	"math"
)

// SegmentTree1 implement segment tree, support sum and update.
type SegmentTree1 struct {
	st   []int
	size int
}

// NewSegmentTree1 used to generate segment tree.
func NewSegmentTree1(arr []int) *SegmentTree1 {
	n := len(arr)
	// height of segment tree,
	height := int(math.Ceil(math.Log(float64(n)) / math.Log(float64(2))))
	nodeSize := 2*int(math.Pow(2, float64(height))) - 1
	st := make([]int, nodeSize)
	s := &SegmentTree1{
		st:   st,
		size: n,
	}
	s.newUtil(arr, 0, n-1, 0)
	return s
}

func (s *SegmentTree1) newUtil(arr []int, ss int, se, si int) int {
	if ss == se {
		s.st[si] = arr[ss]
		return arr[ss]
	}
	mid := getMid(ss, se)
	s.st[si] = s.newUtil(arr, ss, mid, si*2+1) + s.newUtil(arr, mid+1, se, si*2+2)
	return s.st[si]
}

func (s *SegmentTree1) sumUtil(ss, se, qs, qe, si int) int {
	if qs <= ss && qe >= se {
		return s.st[si]
	}
	if se < qs || ss > qe {
		return 0
	}
	mid := getMid(ss, se)
	return s.sumUtil(ss, mid, qs, qe, 2*si+1) + s.sumUtil(mid+1, se, qs, qe, 2*si+2)

}

func (s *SegmentTree1) getSum(qs, qe int) int {
	if qs < 0 || qe > s.size-1 || qs > qe {
		return -1
	}
	return s.sumUtil(0, s.size-1, qs, qe, 0)
}

func (s *SegmentTree1) addUtil(ss, se, i, value, si int) {
	if i < ss || i > se {
		return
	}
	s.st[si] = s.st[si] + value
	if se != ss {
		mid := getMid(ss, se)
		s.addUtil(ss, mid, i, value, 2*si+1)
		s.addUtil(mid+1, se, i, value, 2*si+2)
	}

}

func (s *SegmentTree1) addValue(i, value int) {
	if i < 0 || i > s.size-1 {
		return
	}
	s.addUtil(0, s.size-1, i, value, 0)
}

3. 线段树——区间最值

以下以最小值为例
在这里插入图片描述
go 实现:(ps 命名就是这么随意,没有惊喜,没有意外

package segmenttree

import "math"

// SegmentTree2 implement segment, support mix
type SegmentTree2 struct {
	st   []int
	size int
}

// NewSegmentTree2 used to new a segment tree instance.
func NewSegmentTree2(arr []int) *SegmentTree2 {
	n := len(arr)
	// height of segment tree,
	height := int(math.Ceil(math.Log(float64(n)) / math.Log(float64(2))))
	nodeSize := 2*int(math.Pow(2, float64(height))) - 1
	st := make([]int, nodeSize)
	s := &SegmentTree2{
		st:   st,
		size: n,
	}
	s.newUtil(arr, 0, n-1, 0)
	return s
}

func (s *SegmentTree2) newUtil(arr []int, ss, se, si int) int {
	if ss == se {
		s.st[si] = arr[ss]
		return arr[ss]
	}
	mid := getMid(ss, se)
	s.st[si] = min(s.newUtil(arr, ss, mid, si*2+1), s.newUtil(arr, mid+1, se, si*2+2))
	return s.st[si]
}

func (s *SegmentTree2) findMinUtil(ss, se, qs, qe, index int) int {
	// segment of this node is a part of give range
	if qs <= ss && qe >= se {
		return s.st[index]
	}
	// outside the give range
	if se < qs || ss > qe {
		return math.MaxInt16
	}
	mid := getMid(ss, se)
	return min(s.findMinUtil(ss, mid, qs, qe, 2*index+1), s.findMinUtil(mid+1, se, qs, qe, 2*index+2))
}

func (s *SegmentTree2) findMin(qs, qe int) int {
	if qs < 0 || qe > s.size-1 || qs > qe {
		return -1
	}
	return s.findMinUtil(0, s.size-1, qs, qe, 0)
}

4. 碎碎念

完整代码实现在:

还有一种比较高效的实现方式,请查参考资料。最近真的是越来越拖延了……,各种要写的都木有写,自己都不能忍了的那种,希望春节后能有所改善。

5. 参考资料

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!