模版题-数据结构-线段树

管理员
## 线段树详解 ### 核心概念 线段树是一种二叉树数据结构,用于高效处理区间查询和区间更新操作。 ### 基本结构 ``` 区间 [0, 7] 的线段树: [0,7] (根节点) / \ [0,3] [4,7] / \ / \ [0,1] [2,3] [4,5] [6,7] / \ / \ / \ / \ [0][1][2][3][4][5][6][7] ``` 特点: - 每个节点代表一个区间 - 叶子节点代表单个元素 - 父节点区间 = 左子节点 + 右子节点 ### 为什么需要线段树 **问题场景** 给定数组 nums = [1, 3, 5, 7, 9, 11, 13, 15],需要频繁执行: - 区间查询:求 [2, 5] 的最大值 - 区间更新:将 [3, 6] 的元素都加 10 ``` 暴力方法 vs 线段树 操作 暴力方法 线段树 查询 [2, 5] 最大值 O(n) 遍历 O(log n) 更新 [3, 6] 元素 O(n) 遍历 O(log n) 预处理 O(1) O(n) ``` ### 线段树实现 1. 树的存储 ``` # 使用数组存储(完全二叉树) tree = [0] * (4 * n) # 4n 空间足够 # 节点索引关系: # - 父节点: i // 2 # - 左子节点: i * 2 # - 右子节点: i * 2 + 1 ``` 2. 建树 ``` def build(node, start, end): """构建线段树""" if start == end: # 叶子节点:存储单个元素 tree[node] = nums[start] return mid = (start + end) // 2 # 递归构建左右子树 build(node * 2, start, mid) # 左子树 build(node * 2 + 1, mid + 1, end) # 右子树 # 父节点存储区间信息(如最大值) tree[node] = max(tree[node * 2], tree[node * 2 + 1]) ``` 示例: ``` nums = [1, 3, 5, 7, 9, 11, 13, 15] build(1, 0, 7) # 从根节点开始构建 # 结果 tree 数组: # tree[1] = 15 # [0,7] 的最大值 # tree[2] = 7 # [0,3] 的最大值 # tree[3] = 15 # [4,7] 的最大值 # tree[4] = 3 # [0,1] 的最大值 # ... ``` 3. 区间查询 ``` def query(node, start, end, l, r): """查询区间 [l, r] 的最大值""" # 情况 1:查询区间与当前区间无交集 if r < start or end < l: return float('-inf') # 返回无效值 # 情况 2:查询区间完全包含当前区间 if l <= start and end <= r: return tree[node] # 直接返回节点值 # 情况 3:部分重叠,递归查询 mid = (start + end) // 2 left_max = query(node * 2, start, mid, l, r) right_max = query(node * 2 + 1, mid + 1, end, l, r) return max(left_max, right_max) ``` 示例: ``` # 查询 [2, 5] 的最大值 query(1, 0, 7, 2, 5) # 查询路径: # [0,7] → 部分重叠 → 查询 [0,3] 和 [4,7] # [0,3] → 部分重叠 → 查询 [2,3] # [4,7] → 部分重叠 → 查询 [4,5] # [2,3] → 完全包含 → 返回 7 # [4,5] → 完全包含 → 返回 11 # 结果: max(7, 11) = 11 ``` 4. 区间更新 ``` def update(node, start, end, idx, val): """更新单个元素 nums[idx] = val""" if start == end: # 找到叶子节点,更新值 tree[node] = val return mid = (start + end) // 2 if idx <= mid: # 在左子树 update(node * 2, start, mid, idx, val) else: # 在右子树 update(node * 2 + 1, mid + 1, end, idx, val) # 更新父节点 tree[node] = max(tree[node * 2], tree[node * 2 + 1]) ``` ### 完整示例 ```python class SegmentTree: def __init__(self, nums): self.n = len(nums) self.nums = nums self.tree = [0] * (4 * self.n) self.build(1, 0, self.n - 1) def build(self, node, start, end): if start == end: self.tree[node] = self.nums[start] return mid = (start + end) // 2 self.build(node * 2, start, mid) self.build(node * 2 + 1, mid + 1, end) self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1]) def query(self, l, r): """查询区间 [l, r] 的最大值""" return self._query(1, 0, self.n - 1, l, r) def _query(self, node, start, end, l, r): if r < start or end < l: return float('-inf') if l <= start and end <= r: return self.tree[node] mid = (start + end) // 2 left_max = self._query(node * 2, start, mid, l, r) right_max = self._query(node * 2 + 1, mid + 1, end, l, r) return max(left_max, right_max) def update(self, idx, val): """更新 nums[idx] = val""" self._update(1, 0, self.n - 1, idx, val) def _update(self, node, start, end, idx, val): if start == end: self.tree[node] = val return mid = (start + end) // 2 if idx <= mid: self._update(node * 2, start, mid, idx, val) else: self._update(node * 2 + 1, mid + 1, end, idx, val) self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1]) # 使用示例 nums = [1, 3, 5, 7, 9, 11, 13, 15] st = SegmentTree(nums) # 查询 [2, 5] 的最大值 print(st.query(2, 5)) # 输出: 11 # 更新 nums[3] = 20 st.update(3, 20) # 再次查询 [2, 5] 的最大值 print(st.query(2, 5)) # 输出: 20 ``` ### 在最大二叉树中的应用 ```python class Solution: def constructMaximumBinaryTree(self, nums): if not nums: return None n = len(nums) tree = [0] * (4 * n) # 构建线段树(存储最大值的索引) def build(node, start, end): if start == end: tree[node] = start return mid = (start + end) // 2 build(node * 2, start, mid) build(node * 2 + 1, mid + 1, end) # 存储较大值的索引 left_idx = tree[node * 2] right_idx = tree[node * 2 + 1] tree[node] = left_idx if nums[left_idx] > nums[right_idx] else right_idx # 查询区间最大值的索引 def query(node, start, end, l, r): if r < start or end < l: return -1 if l <= start and end <= r: return tree[node] mid = (start + end) // 2 left_idx = query(node * 2, start, mid, l, r) right_idx = query(node * 2 + 1, mid + 1, end, l, r) if left_idx == -1: return right_idx if right_idx == -1: return left_idx return left_idx if nums[left_idx] > nums[right_idx] else right_idx # 构建二叉树 def build_tree(start, end): if start >= end: return None root_idx = query(1, 0, n - 1, start, end - 1) root = TreeNode(nums[root_idx]) root.left = build_tree(start, root_idx) root.right = build_tree(root_idx + 1, end) return root build(1, 0, n - 1) return build_tree(0, n) ``` ``` 复杂度分析 操作 时间复杂度 空间复杂度 建树 O(n) O(n) 区间查询 O(log n) O(log n) 递归栈 单点更新 O(log n) O(log n) 递归栈 ``` ### 适用场景 适合: - 频繁的区间查询 - 频繁的区间更新 - 数组大小固定 不适合: - 数组频繁插入/删除 - 只需要单点查询(用数组即可) - 数据量很小(暴力法更快) 总结 - 线段树 = 区间操作的加速器 - 建树:O(n) 预处理 - 查询:O(log n) 快速查询 - 更新:O(log n) 快速更新 核心思想:将区间分解为 O(log n) 个子区间,利用预处理的信息快速计算。
评论 0

发表评论 取消回复

Shift+Enter 换行  ·  Enter 发送
还没有评论,来发表第一条吧