题目: 定义二叉树节点:
class TreeNode:
    """A binary-tree node holding one value and optional left/right children."""

    def __init__(self, value):
        # A fresh node starts out as a leaf.
        self._value = value
        self._left = self._right = None

    def get_value(self):
        """Return the value stored in this node."""
        return self._value

    def get_left(self):
        """Return the left child, or ``None`` when there is none."""
        return self._left

    def get_right(self):
        """Return the right child, or ``None`` when there is none."""
        return self._right

    def set_left(self, left):
        """Make ``left`` the left child of this node."""
        self._left = left

    def set_right(self, right):
        """Make ``right`` the right child of this node."""
        self._right = right


# Definition of the MaxTree of an array:
#   1) the array contains no duplicate elements;
#   2) the MaxTree is a binary tree in which every array value corresponds
#      to exactly one node;
#   3) in the MaxTree and in every one of its subtrees, the node holding the
#      largest value is the root.
给定一个没有重复元素的数组,生成MaxTree,要求时间复杂度O(n), 空间复杂度O(n)。
思路: 1.看到这个题目,首先想到的是“分而治之”: 1)找到数组的最大值,作为根节点 2)根节点左边的元素生成左子树,右边的元素生成右子树。注意:每层递归都要线性扫描找最大值并切片复制数组,平均时间复杂度为O(n log n),最坏(数组有序时)为O(n²),并不满足题目O(n)的要求,这里只作为后续解法的对照。
def get_max_tree_rec(arr):
    """Build the MaxTree of ``arr`` by divide and conquer.

    The largest element becomes the root; the elements before it form the
    left subtree and the elements after it form the right subtree, each
    built recursively from a slice of ``arr``.

    Args:
        arr: sequence of mutually comparable values.

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    if len(arr) == 0:
        return None

    # Linear scan for the position of the largest value; the strict ``>``
    # keeps the first occurrence.
    pivot = 0
    for idx, value in enumerate(arr):
        if value > arr[pivot]:
            pivot = idx

    before, after = arr[:pivot], arr[pivot + 1:]
    root = TreeNode(arr[pivot])
    root.set_left(get_max_tree_rec(before))
    root.set_right(get_max_tree_rec(after))
    return root


# 2. Of course, Python's built-in functions can make the code more concise:
def get_max_tree_rec2(arr):
    """Build the MaxTree of ``arr`` by divide and conquer, using built-ins.

    Same strategy as the hand-written recursive version, but the position of
    the largest element is found with ``max`` and ``index``.

    Args:
        arr: sequence supporting ``len``, slicing and ``.index`` (for
            example a list or a tuple) whose items are mutually comparable.

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    if len(arr) == 0:
        return None
    m = arr.index(max(arr))
    root = TreeNode(arr[m])
    # Recurse into this function itself; the original delegated to
    # get_max_tree_rec, so only the top level used the built-in variant.
    root.set_left(get_max_tree_rec2(arr[:m]))
    root.set_right(get_max_tree_rec2(arr[m + 1:]))
    return root


# 3. The solution from the book:
def pop_stack_and_set_bigger_node(stack, bigger_node_map):
    """Pop the top of ``stack`` and record what lies directly beneath it.

    The popped entry is mapped, in ``bigger_node_map``, to the new top of the
    stack, or to ``None`` when the stack became empty.

    Args:
        stack: non-empty list used as a stack; it is mutated.
        bigger_node_map: dict that receives the ``popped -> beneath`` entry.
    """
    popped = stack.pop()
    # Must be evaluated after the pop: the "bigger" entry is the new top.
    bigger_node_map[popped] = stack[-1] if stack else None


def get_max_tree(arr):
    """Build a MaxTree of ``arr`` with two monotonic-stack passes.

    For every node the nearest bigger node on its left and on its right are
    looked up; the smaller of the two becomes its parent, and the node with
    no bigger neighbour on either side is the root.  Each node is pushed and
    popped at most once per pass.

    A child is attached to the parent's left slot when that slot is free and
    to the right slot otherwise, regardless of its position in ``arr``.

    Args:
        arr: sequence of mutually comparable values (the problem statement
            requires them to be distinct).

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    nodes = [TreeNode(value) for value in arr]
    left_first_bigger_node_map = {}
    right_first_bigger_node_map = {}
    stack = []

    # Pass 1, left to right: nearest bigger node on the left of each node.
    for cur in nodes:
        while stack and stack[-1].get_value() < cur.get_value():
            pop_stack_and_set_bigger_node(stack, left_first_bigger_node_map)
        stack.append(cur)
    while stack:
        pop_stack_and_set_bigger_node(stack, left_first_bigger_node_map)

    # Pass 2, right to left: nearest bigger node on the right of each node.
    # The stack was fully drained above, so it can be reused as is.
    for cur in reversed(nodes):
        while stack and stack[-1].get_value() < cur.get_value():
            pop_stack_and_set_bigger_node(stack, right_first_bigger_node_map)
        stack.append(cur)
    while stack:
        pop_stack_and_set_bigger_node(stack, right_first_bigger_node_map)

    root = None
    for cur in nodes:
        left = left_first_bigger_node_map[cur]
        right = right_first_bigger_node_map[cur]
        if left is None and right is None:
            # Nothing bigger on either side: this is the maximum.
            root = cur
            continue

        if left is None:
            parent = right
        elif right is None:
            parent = left
        else:
            # Both neighbours exist: the smaller one is the parent.
            parent = left if left.get_value() < right.get_value() else right

        if parent.get_left() is None:
            parent.set_left(cur)
        else:
            parent.set_right(cur)
    return root


# 4. Simplify: extract the search for the first bigger node into a function
#    and streamline the tree construction:
def get_first_bigger_nodes(nodes, bigger_node_map):
    """Fill ``bigger_node_map`` with the nearest preceding bigger node.

    ``nodes`` is scanned in order with a stack.  Every node is mapped to the
    entry that sits directly beneath it on the stack when it is popped, or to
    ``None`` when nothing is beneath it.

    Args:
        nodes: sequence of objects exposing ``get_value()``.
        bigger_node_map: dict that is filled in place.
    """
    pending = []
    for node in nodes:
        while pending and pending[-1].get_value() < node.get_value():
            pop_stack_and_set_bigger_node(pending, bigger_node_map)
        pending.append(node)
    # Whatever is left never met a bigger successor; drain it.
    while pending:
        pop_stack_and_set_bigger_node(pending, bigger_node_map)


def get_max_tree3(arr):
    """Build a MaxTree of ``arr`` using ``get_first_bigger_nodes`` twice.

    The helper runs once over the nodes in order and once over them in
    reverse.  Each node's parent is the smaller of its two bigger
    neighbours; the node with no bigger neighbour is the root.  A child
    takes the parent's left slot when it is free, else the right slot.

    Args:
        arr: sequence of mutually comparable values (the problem statement
            requires them to be distinct).

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    nodes = [TreeNode(value) for value in arr]
    left_first_bigger_node_map, right_first_bigger_node_map = {}, {}
    get_first_bigger_nodes(nodes, left_first_bigger_node_map)
    get_first_bigger_nodes(nodes[::-1], right_first_bigger_node_map)

    root = None
    for cur in nodes:
        left = left_first_bigger_node_map[cur]
        right = right_first_bigger_node_map[cur]
        if left is None and right is None:
            root = cur
            continue

        if left is None or right is None:
            # Only one side has a bigger node: that node is the parent.
            parent = right if left is None else left
        else:
            parent = left if left.get_value() < right.get_value() else right

        attach = parent.set_left if parent.get_left() is None else parent.set_right
        attach(cur)
    return root


# 5. Record indices instead of nodes, and use lists instead of dicts
def get_max_tree5(arr):
    """Build the MaxTree of ``arr`` working on indices instead of nodes.

    Two stack scans record, for every position, the index of the nearest
    bigger element on the left and on the right (``-1`` when there is none).
    The smaller of the two is the parent.  A child placed before its parent
    in ``arr`` becomes the left child, otherwise the right child.

    Args:
        arr: indexable sequence of mutually comparable values (the problem
            statement requires them to be distinct).

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    size = len(arr)
    nodes = [TreeNode(value) for value in arr]
    left_first_bigger_index_map = [-1] * size
    right_first_bigger_index_map = [-1] * size

    def scan(order, table):
        # When an index is popped, the index beneath it (if any) refers to
        # the nearest bigger element met earlier in ``order``.
        stack = []
        for i in order:
            while stack and arr[stack[-1]] < arr[i]:
                top = stack.pop()
                if stack:
                    table[top] = stack[-1]
            stack.append(i)
        while stack:
            top = stack.pop()
            if stack:
                table[top] = stack[-1]

    scan(range(size), left_first_bigger_index_map)
    scan(range(size - 1, -1, -1), right_first_bigger_index_map)

    root = None
    for i, cur in enumerate(nodes):
        left = left_first_bigger_index_map[i]
        right = right_first_bigger_index_map[i]
        if left == -1 and right == -1:
            root = cur
            continue

        if left == -1:
            parent_index = right
        elif right == -1:
            parent_index = left
        else:
            parent_index = left if arr[left] < arr[right] else right

        parent = nodes[parent_index]
        if parent_index > i:
            parent.set_left(cur)
        else:
            parent.set_right(cur)
    return root


# 6. A simpler way to find the index of the first bigger element:
def get_max_tree6(arr):
    """Build the MaxTree of ``arr`` from nearest-bigger index tables.

    The nearest bigger index is recorded when an index is pushed: after all
    smaller entries are popped, the stack top (if any) is the answer.  The
    smaller of the left and right candidates is the parent.  A child placed
    before its parent in ``arr`` becomes the left child, otherwise the right
    child.

    Args:
        arr: indexable sequence of mutually comparable values (the problem
            statement requires them to be distinct).

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    def nearest_bigger(order):
        # ``order`` is the sequence of indices to visit; -1 marks "none".
        stack = []
        result = [-1] * len(arr)
        for i in order:
            while stack and arr[i] > arr[stack[-1]]:
                stack.pop()
            if stack:
                result[i] = stack[-1]
            stack.append(i)
        return result

    nodes = [TreeNode(value) for value in arr]
    left_first_bigger_index_map = nearest_bigger(range(len(arr)))
    right_first_bigger_index_map = nearest_bigger(range(len(arr) - 1, -1, -1))

    root = None
    for i, cur in enumerate(nodes):
        left = left_first_bigger_index_map[i]
        right = right_first_bigger_index_map[i]
        if left == -1 and right == -1:
            root = cur
            continue

        if left == -1:
            parent_index = right
        elif right == -1:
            parent_index = left
        else:
            parent_index = left if arr[left] < arr[right] else right

        parent = nodes[parent_index]
        if parent_index > i:
            parent.set_left(cur)
        else:
            parent.set_right(cur)
    return root


# 7. Merge the two passes into one: find the bigger elements on both sides at
#    the same time and, when both exist, keep the smaller of the two
def get_max_tree7(arr):
    """Build the MaxTree of ``arr`` with a single monotonic-stack pass.

    One scan computes each position's parent index directly.  When an index
    is pushed, its left bigger neighbour is the stack top.  When it is later
    popped by a bigger element, that element replaces the recorded
    neighbour if none was recorded or if it is smaller.

    Args:
        arr: indexable sequence of mutually comparable values (the problem
            statement requires them to be distinct).

    Returns:
        The root ``TreeNode``, or ``None`` when ``arr`` is empty.
    """
    size = len(arr)
    parent_of = [-1] * size  # -1 means "no bigger neighbour": the root
    stack = []
    for i in range(size):
        while stack and arr[i] > arr[stack[-1]]:
            top = stack.pop()
            known = parent_of[top]
            # arr[i] is the right bigger neighbour of ``top``; keep it when
            # no left neighbour was recorded or when it is the smaller one.
            if known == -1 or arr[i] < arr[known]:
                parent_of[top] = i
        if stack:
            parent_of[i] = stack[-1]
        stack.append(i)

    nodes = [TreeNode(value) for value in arr]
    root = None
    for i, (node, parent_index) in enumerate(zip(nodes, parent_of)):
        if parent_index == -1:
            root = node
        elif parent_index > i:
            # The parent sits to the right, so this node is its left child.
            nodes[parent_index].set_left(node)
        else:
            nodes[parent_index].set_right(node)
    return root


# Tests:
import random
import time
from functools import wraps


def timethis(func):
    """Decorate ``func`` so every call prints its name and elapsed time."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is a monotonic, high-resolution clock meant for
        # measuring short durations.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(func.__name__, 'cost time:', elapsed)
        return result
    return wrapper


def traverse_inorder(root, result):
    """Append the values of the tree under ``root`` to ``result`` in in-order.

    Implemented iteratively so a degenerate (list-shaped) tree cannot hit
    the interpreter's recursion limit.
    """
    stack = []
    node = root
    while stack or node is not None:
        # Walk as far left as possible, remembering the path.
        while node is not None:
            stack.append(node)
            node = node.get_left()
        node = stack.pop()
        result.append(node.get_value())
        node = node.get_right()


def traverse_preorder(root, result):
    """Append the values of the tree under ``root`` to ``result`` in pre-order.

    Implemented iteratively so a degenerate (list-shaped) tree cannot hit
    the interpreter's recursion limit.
    """
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        result.append(node.get_value())
        # Push right first so the left subtree is visited first.
        stack.append(node.get_right())
        stack.append(node.get_left())


def _build_and_flatten(build, arr):
    """Build a tree with ``build(arr)`` and return its pre-order value list."""
    result = []
    traverse_preorder(build(arr), result)
    return result


@timethis
def test_get_max_tree_rec(arr):
    """Time get_max_tree_rec on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree_rec, arr)


@timethis
def test_get_max_tree(arr):
    """Time get_max_tree on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree, arr)


@timethis
def test_get_max_tree3(arr):
    """Time get_max_tree3 on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree3, arr)


@timethis
def test_get_max_tree_rec2(arr):
    """Time get_max_tree_rec2 on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree_rec2, arr)


@timethis
def test_get_max_tree5(arr):
    """Time get_max_tree5 on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree5, arr)


@timethis
def test_get_max_tree6(arr):
    """Time get_max_tree6 on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree6, arr)


@timethis
def test_get_max_tree7(arr):
    """Time get_max_tree7 on ``arr`` and return the pre-order values."""
    return _build_and_flatten(get_max_tree7, arr)


def test1(count):
    """Run every builder on one shuffled ``range(count)`` and compare them.

    Args:
        count: number of distinct integers in the generated array.

    Raises:
        Exception: when two consecutive implementations disagree on the
            pre-order traversal (including a difference in length).
    """
    arr = list(range(count))
    random.shuffle(arr)

    # NOTE(review): the result of this first call is discarded; presumably a
    # warm-up run -- kept so the printed output is unchanged.
    test_get_max_tree_rec(arr)
    result1 = test_get_max_tree(arr)
    result3 = test_get_max_tree3(arr)
    result4 = test_get_max_tree_rec(arr)
    result5 = test_get_max_tree_rec2(arr)
    result7 = test_get_max_tree5(arr)
    result8 = test_get_max_tree6(arr)
    result9 = test_get_max_tree7(arr)

    # Compare whole lists so a length mismatch is caught too; the original
    # index loop only covered the first len(result1) positions.
    checks = (
        (result1, result3, 'Error 2'),
        (result3, result4, 'Error 4'),
        (result4, result5, 'Error 5'),
        (result5, result7, 'Error 6'),
        (result7, result8, 'Error 8'),
        (result8, result9, 'Error 9'),
    )
    for expected, actual, label in checks:
        if expected != actual:
            raise Exception(label)


if __name__ == '__main__':
    test1(100000)


# Results:
test_get_max_tree_rec cost time: 0.35343265533447266 test_get_max_tree cost time: 0.4225938320159912 test_get_max_tree3 cost time: 0.42000365257263184 test_get_max_tree_rec cost time: 0.390531063079834 test_get_max_tree_rec2 cost time: 0.3676440715789795 test_get_max_tree5 cost time: 0.2665879726409912 test_get_max_tree6 cost time: 0.2726738452911377 test_get_max_tree7 cost time: 0.2191329002380371