[算法]主席树区间第k小(python)

mac2024-05-22  24

主席树

python版本实现主席树,使用面向对象的方式而不是数组,代码更易懂

题目

关于主席树的几个核心问题:

主席树与线段树的区别:

虽然树节点中都有left和right属性,但是注意,主席树与线段树的left和right含义不同

a =[1,2,3,4,5,6] class TreeNode(object): def __init__(self): self.left = -1 self.right = -1

线段树的left和right存放的是左右端点下标,:

即TreeNode.num(线段树中第i个节点的权值)={ a[left],a[right}

主席树的eft和right存放的是对应的是值而不是下标

即:left<=TreeNode.num(主席树中第i个节点的权值)<=right

主席树与线段树的建树不同

主席树首先建立一棵空树,待空树建立完毕后,遍历数组,将数据依次插入.每次插入数据都会重新创立一个根节点,每个根节点对应一个"版本".每次插入数据都会更新从跟节点到叶子节点中路径的值线段树只需一次递归建树即可.第一次建树即可初始化所有的节点.

主席树是权值线段树

权值线段树

权值线段树就是每个节点都带有权值的树.

线段树节点的权值,实际上就是代表数组中有多少个数落入到该节点中

例如,一列数,n为6,数分别为1 3 2 3 6 1

首先,每棵树都是这样的:(空树)

以第4棵线段树为例,1~4的数分别为1 3 2 3

图片来自主席树

关于离散化:

这里的离散化实际上就是一种数据缩放,并保持原来的大小关系,

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WSfn6LLW-1572508591938)(https://images2017.cnblogs.com/blog/1309909/201801/1309909-20180117194204240-714306710.png)]

(图片摘自博客)

实现步骤:

将原数组排序利用二分搜索,遍历排序后数组,并获得排序后的元素所在的原数组的下标.将2中得到的结果保存起来,即是离散化后的数值, arr = [25957, 6405, 15770, 26287, 26465, ] arr2 = sorted(arr) # 排序 [6405, 15770,25957 ,26287,26465] z = list(map(lambda x: bisect.bisect(arr2, x), arr)) # [3, 1, 2, 4, 5]
主席树的前缀和思想

所谓的前缀和,在这道题目中就是 求[2,4]中的第k小可以用[0,4]-[0,1]=[2,4]的方法,即用第4棵树减去低1棵树即可,因此任意一个区间的线段树都可以用原有的线段树做差求出.

主席树如何节省空间

在建树中说道,每个根节点对应一个版本,如果每次建树都是完整的建树的话,就会十分的浪费空间.因此引入了版本空值

主席树的版本控制(可持久化)

所谓主席树的版本控制,其实就是不同树之间共享节点,通过访问不同的根节点已达到访问不同的树.

现在举几个例子来说明 序列4 3 2 3 6 1

区间[1,1]的线段树(蓝色节点为新节点)

区间[1,2]的线段树(橙色节点为新节点)

区间[1,3]的线段树(紫色节点为新节点)

图片来自主席树

如何建树:

使用数组来储存元素,这种方法会浪费很多空间,线段树是4倍空间,主席树是32倍空间

使用链表来储存元素:这里使用链表建树

# 递归建一棵空树 def build(l, r): node = TreeNode() node.l = l node.r = r if l == r: return node else: m = (l + r) >> 1 node_left = build(l, m) node_right = build(m + 1, r) node.left_node = node_left node.right_node = node_right return node
如何搜索第k小

首先利用前缀和思想获取范围,然后就像二叉搜索树找第k小一样,

如果左子树的权值大于等于k,说明左子树中包含了第k小的元素如果左子树的权值小k,说明左子树中不包含第k小的元素,要去右子树中找第k-左子树的权值小的元素

代码

面向对象思想存放TreeNode节点递归建树,使用链表bisect库二分搜索,copy库进行节点复制有几个用例过不了,超时,可能是python语言的问题 import bisect import copy class TreeNode(object): def __init__(self): self.left_node = None self.right_node = None self.num = 0 self.l = -1 self.r = -1 # 打印函数 def __str__(self): # return '[%s,%s,] num:%s, %s' % (self.l, self.r, self.num, id(self)) # 查看地址,确实新建了部分节点 return '[%s,%s,] num:%s,' % (self.l, self.r, self.num) # 打印当前树形结构 def _show_arr(self, node, ): print(node) if node.l == node.r: return else: self._show_arr(node.left_node) self._show_arr(node.right_node) def show_arr(self, ): self._show_arr(self) # 打印区间求差之后的树形结构 def show_diff(self, node2): self._show_diff(self, node2) def _show_diff(self, node, node2): print(node.l, node.r, node.num - node2.num) if node.l == node.r: return else: self._show_diff(node.left_node, node2.left_node) self._show_diff(node.right_node, node2.right_node) # sum数组:记录节点权值 # p:记录离散化后序列长度,也是线段树的区间最大长度 # 递归建一棵空树 def build(l, r): node = TreeNode() node.l = l node.r = r if l == r: return node else: m = (l + r) >> 1 node_left = build(l, m) node_right = build(m + 1, r) node.left_node = node_left node.right_node = node_right return node def insert(x, node: TreeNode): node.num += 1 if node.l == node.r: # 已经到了子节点了 return m = (node.l + node.r) >> 1 if m >= x: # 左子树的最大值大于了该值,搜索左子树 left_node = copy.copy(node.left_node) # 复制一份节点 node.left_node = left_node insert(x, node.left_node) if m < x: # 右子树的最小值小于该值 right_node = copy.copy(node.right_node) # 复制一份节点 node.right_node = right_node insert(x, node.right_node) def find_k(nl: TreeNode, nr: TreeNode, k): if nr.l == nr.r: return nr.l left_num_diff = nr.left_node.num - nl.left_node.num if k <= left_num_diff: return find_k(nl.left_node, nr.left_node, k) else: return find_k(nl.right_node, nr.right_node, k - left_num_diff) # 落谷用例 def test(): arr = [25957, 6405, 15770, 26287, 26465, ] arr2 = sorted(arr) # 排序 [6405, 15770,25957 ,26287,26465] z = list(map(lambda x: bisect.bisect(arr2, x), arr)) # [3, 1, 2, 4, 5] n = build(1, len(z)) rt = [] rt.append(n) for x in z: n2 = copy.copy(rt[-1]) # 复制最后一个版本的树 insert(x, n2) # 将值添加进去 rt.append(n2) # n2.show_arr() # print() # 2 2 1 res = find_k(rt[1], rt[2], 1) print(res) print(arr2[res - 1]) # 1 2 2 res = find_k(rt[0], rt[2], 2) print(res) print(arr2[res - 1]) # 4 4 1 res = find_k(rt[3], rt[5], 1) print(res) print(arr2[res-1]) if __name__ == '__main__': # test() line1 = [int(x) for x in input().strip().split(" ")] n = line1[0] # 数字的个数 m = line1[1] # 查询的个数 arr = [int(x) for x in input().strip().split(" ")] # 离散化 arr2 = sorted(arr) # 排序 z = list(map(lambda x: bisect.bisect(arr2, x), arr)) rt = build(1, len(z)) rt_arr = [rt] for x in z: rt_temp = copy.copy(rt_arr[-1]) # 复制最后一个版本的树 insert(x, rt_temp) # 将值添加进去 rt_arr.append(rt_temp) for i in range(m): line = [int(x) for x in input().split(" ")] res = find_k(rt_arr[line[0] - 1], rt_arr[line[1]], line[2]) print(arr2[res-1])

参考博客

主席树

主席树

最新回复(0)