主席树
python版本实现主席树,使用面向对象的方式而不是数组,代码更易懂
题目
关于主席树的几个核心问题:
主席树与线段树的区别:
虽然树节点中都有left和right属性,但是注意,主席树与线段树的left和right含义不同
a
=[1,2,3,4,5,6]
class TreeNode(object):
def __init__(self
):
self
.left
= -1
self
.right
= -1
线段树的left和right存放的是左右端点下标,:
即TreeNode.num(线段树中第i个节点的权值)={ a[left],a[right}
主席树的eft和right存放的是对应的是值而不是下标
即:left<=TreeNode.num(主席树中第i个节点的权值)<=right
主席树与线段树的建树不同
主席树首先建立一棵空树,待空树建立完毕后,遍历数组,将数据依次插入.每次插入数据都会重新创立一个根节点,每个根节点对应一个"版本".每次插入数据都会更新从跟节点到叶子节点中路径的值线段树只需一次递归建树即可.第一次建树即可初始化所有的节点.
主席树是权值线段树
权值线段树
权值线段树就是每个节点都带有权值的树.
线段树节点的权值,实际上就是代表数组中有多少个数落入到该节点中
例如,一列数,n为6,数分别为1 3 2 3 6 1
首先,每棵树都是这样的:(空树)
以第4棵线段树为例,1~4的数分别为1 3 2 3
图片来自主席树
关于离散化:
这里的离散化实际上就是一种数据缩放,并保持原来的大小关系,
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WSfn6LLW-1572508591938)(https://images2017.cnblogs.com/blog/1309909/201801/1309909-20180117194204240-714306710.png)]
(图片摘自博客)
实现步骤:
将原数组排序利用二分搜索,遍历排序后数组,并获得排序后的元素所在的原数组的下标.将2中得到的结果保存起来,即是离散化后的数值,
arr
= [25957, 6405, 15770, 26287, 26465, ]
arr2
= sorted(arr
)
z
= list(map(lambda x
: bisect
.bisect
(arr2
, x
), arr
))
主席树的前缀和思想
所谓的前缀和,在这道题目中就是 求[2,4]中的第k小可以用[0,4]-[0,1]=[2,4]的方法,即用第4棵树减去低1棵树即可,因此任意一个区间的线段树都可以用原有的线段树做差求出.
主席树如何节省空间
在建树中说道,每个根节点对应一个版本,如果每次建树都是完整的建树的话,就会十分的浪费空间.因此引入了版本空值
主席树的版本控制(可持久化)
所谓主席树的版本控制,其实就是不同树之间共享节点,通过访问不同的根节点已达到访问不同的树.
现在举几个例子来说明 序列4 3 2 3 6 1
区间[1,1]的线段树(蓝色节点为新节点)
区间[1,2]的线段树(橙色节点为新节点)
区间[1,3]的线段树(紫色节点为新节点)
图片来自主席树
如何建树:
使用数组来储存元素,这种方法会浪费很多空间,线段树是4倍空间,主席树是32倍空间
使用链表来储存元素:这里使用链表建树
def build(l
, r
):
node
= TreeNode
()
node
.l
= l
node
.r
= r
if l
== r
:
return node
else:
m
= (l
+ r
) >> 1
node_left
= build
(l
, m
)
node_right
= build
(m
+ 1, r
)
node
.left_node
= node_left
node
.right_node
= node_right
return node
如何搜索第k小
首先利用前缀和思想获取范围,然后就像二叉搜索树找第k小一样,
如果左子树的权值大于等于k,说明左子树中包含了第k小的元素如果左子树的权值小k,说明左子树中不包含第k小的元素,要去右子树中找第k-左子树的权值小的元素
代码
面向对象思想存放TreeNode节点递归建树,使用链表bisect库二分搜索,copy库进行节点复制有几个用例过不了,超时,可能是python语言的问题
import bisect
import copy
class TreeNode(object):
def __init__(self
):
self
.left_node
= None
self
.right_node
= None
self
.num
= 0
self
.l
= -1
self
.r
= -1
def __str__(self
):
return '[%s,%s,] num:%s,' % (self
.l
, self
.r
, self
.num
)
def _show_arr(self
, node
, ):
print(node
)
if node
.l
== node
.r
:
return
else:
self
._show_arr
(node
.left_node
)
self
._show_arr
(node
.right_node
)
def show_arr(self
, ):
self
._show_arr
(self
)
def show_diff(self
, node2
):
self
._show_diff
(self
, node2
)
def _show_diff(self
, node
, node2
):
print(node
.l
, node
.r
, node
.num
- node2
.num
)
if node
.l
== node
.r
:
return
else:
self
._show_diff
(node
.left_node
, node2
.left_node
)
self
._show_diff
(node
.right_node
, node2
.right_node
)
def build(l
, r
):
node
= TreeNode
()
node
.l
= l
node
.r
= r
if l
== r
:
return node
else:
m
= (l
+ r
) >> 1
node_left
= build
(l
, m
)
node_right
= build
(m
+ 1, r
)
node
.left_node
= node_left
node
.right_node
= node_right
return node
def insert(x
, node
: TreeNode
):
node
.num
+= 1
if node
.l
== node
.r
:
return
m
= (node
.l
+ node
.r
) >> 1
if m
>= x
:
left_node
= copy
.copy
(node
.left_node
)
node
.left_node
= left_node
insert
(x
, node
.left_node
)
if m
< x
:
right_node
= copy
.copy
(node
.right_node
)
node
.right_node
= right_node
insert
(x
, node
.right_node
)
def find_k(nl
: TreeNode
, nr
: TreeNode
, k
):
if nr
.l
== nr
.r
:
return nr
.l
left_num_diff
= nr
.left_node
.num
- nl
.left_node
.num
if k
<= left_num_diff
:
return find_k
(nl
.left_node
, nr
.left_node
, k
)
else:
return find_k
(nl
.right_node
, nr
.right_node
, k
- left_num_diff
)
def test():
arr
= [25957, 6405, 15770, 26287, 26465, ]
arr2
= sorted(arr
)
z
= list(map(lambda x
: bisect
.bisect
(arr2
, x
), arr
))
n
= build
(1, len(z
))
rt
= []
rt
.append
(n
)
for x
in z
:
n2
= copy
.copy
(rt
[-1])
insert
(x
, n2
)
rt
.append
(n2
)
res
= find_k
(rt
[1], rt
[2], 1)
print(res
)
print(arr2
[res
- 1])
res
= find_k
(rt
[0], rt
[2], 2)
print(res
)
print(arr2
[res
- 1])
res
= find_k
(rt
[3], rt
[5], 1)
print(res
)
print(arr2
[res
-1])
if __name__
== '__main__':
line1
= [int(x
) for x
in input().strip
().split
(" ")]
n
= line1
[0]
m
= line1
[1]
arr
= [int(x
) for x
in input().strip
().split
(" ")]
arr2
= sorted(arr
)
z
= list(map(lambda x
: bisect
.bisect
(arr2
, x
), arr
))
rt
= build
(1, len(z
))
rt_arr
= [rt
]
for x
in z
:
rt_temp
= copy
.copy
(rt_arr
[-1])
insert
(x
, rt_temp
)
rt_arr
.append
(rt_temp
)
for i
in range(m
):
line
= [int(x
) for x
in input().split
(" ")]
res
= find_k
(rt_arr
[line
[0] - 1], rt_arr
[line
[1]], line
[2])
print(arr2
[res
-1])
参考博客
主席树
主席树