线段树(区间树) 对于给定区间 更新:更新区间中一个元素或者一个区间的值 查询:查询一个区间[i,j]的最大值,最小值或者区间数字和 线段树是平衡二叉树(对于整棵树来说,最大深度和最小深度相差不超过一)(堆也是平衡二叉树) “ 如果区间有n个元素,数组表示需要有多少节点? 0层:1 1层:2 2层:4 3层:8 … h-1层:2^(h-1) ” 对满二叉树: h层,一共有2h-1个节点(大约是2h) 最后一层(h-1层),有2^(h-1)个节点 最后一层的节点数大致等于前面所有层节点之和 需要大概4n的空间(可能会浪费一些空间) 代码实现 接口:
public interface Merger<E>{ E merger(E a, E b); } public class SegmentTree<E> { private E[] tree; private E[] data; private Merger<E> merger; public SegmentTree (E[] arr, Merger<E> merger){ this.merger = merger; data = (E[])new Object[arr.length]; for (int i = 0; i < arr.length; i ++){ data[i] = arr[i]; } tree = (E[])new Object[4 * arr.length]; buildSegmentTree(0, 0, data.length - 1); } //在treeIndex的位置创建表示区间[l...r] private void buildSegmentTree(int treeIndex, int l, int r){ if (l == r){ tree[treeIndex] = data[l]; return; } int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); int mid = l + (r - l) / 2; buildSegmentTree(leftTreeIndex, l, mid); buildSegmentTree(rightTreeIndex, mid + 1, r); tree[treeIndex] = merger.merger(tree[leftTreeIndex], tree[rightTreeIndex]); } public int getSize(){ return data.length; } public E get (int index){ if (index < 0 || index >= data.length){ throw new IllegalArgumentException("Index is illegal."); } return data[index]; } //返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引 private int leftChild(int index){ return 2 * index + 1; } //返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引 private int rightChild(int index){ return 2 * index + 2; } //返回区间[queryL, queryR]的值 public E query(int queryL, int queryR){ if (queryL < 0 || queryL >=data.length || queryR < 0 || queryR >= data.length || queryL > queryR) throw new IllegalArgumentException("Index is length."); return query(0, 0, data.length - 1, queryL, queryR); } //在以treeID为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值 private E query (int treeIndex, int l, int r, int queryL, int queryR){ if (l == queryL && r == queryR) return tree[treeIndex]; int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (queryL >= mid + 1) return query(rightTreeIndex, mid + 1, r, queryL, queryR); else if (queryR <= mid) return query(leftTreeIndex, l, mid, queryL, queryR); E leftResult = query(leftTreeIndex, l, mid, queryL, mid); E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1 ,queryR); return merger.merger(leftResult, rightResult); } @Override public String toString(){ StringBuilder res = new StringBuilder(); res.append('['); for (int i = 0; i < tree.length; i ++){ if (tree[i] != null){ res.append(tree[i]); }else{ res.append("null"); } if(i == tree.length - 1){ res.append(']'); } } return res.toString(); } //将index位置的值,更新为e public void set(int index, E e){ if (index < 0 || index >= data.length) throw new IllegalArgumentException("Index is illegal"); data[index] = e; set(0, 0, data.length - 1, index, e); } // 在以treeIndex为根的线段树中更新index的值为e private void set(int treeIndex, int l, int r, int index, E e){ if(l == r){ tree[index] = e; return; } int mid = l + (r - l) / 2; int leftTreeIndex = leftChild(treeIndex); int rightTreeIndex = rightChild(treeIndex); if (index >= mid + 1) set(rightTreeIndex, mid + 1, r, index, e); else set(leftTreeIndex, l, mid, index, e); tree[treeIndex] = merger.merger(tree[leftTreeIndex], tree[rightTreeIndex]); } @Override public String toString(){ StringBuilder res = new StringBuilder(); res.add('['); for (int i = 0; i < tree.length; i ++){ if (tree[i] != null){ res.append(tree[i]); }else{ res.append("null"); } if(i != tree.length - 1){ res.append(']'); } } } } Main函数 public class Main{ public static void main(String[] args) { Integer[] nums = {-2, 0, 3, -5, 2, -1}; SegmentTree<Integer> segTree = new SegmentTree<>(nums, (a, b) -> a + b); System.out.println(segTree); // System.out.println(segTree.query(0,2)); } }