模版题-数据结构-线段树
## 线段树详解
### 核心概念
线段树是一种二叉树数据结构,用于高效处理区间查询和区间更新操作。
### 基本结构
```
区间 [0, 7] 的线段树:
[0,7] (根节点)
/ \
[0,3] [4,7]
/ \ / \
[0,1] [2,3] [4,5] [6,7]
/ \ / \ / \ / \
[0][1][2][3][4][5][6][7]
```
特点:
- 每个节点代表一个区间
- 叶子节点代表单个元素
- 父节点区间 = 左子节点 + 右子节点
### 为什么需要线段树
**问题场景**
给定数组 nums = [1, 3, 5, 7, 9, 11, 13, 15],需要频繁执行:
- 区间查询:求 [2, 5] 的最大值
- 区间更新:将 [3, 6] 的元素都加 10
```
暴力方法 vs 线段树
操作 暴力方法 线段树
查询 [2, 5] 最大值 O(n) 遍历 O(log n)
更新 [3, 6] 元素 O(n) 遍历 O(log n)
预处理 O(1) O(n)
```
### 线段树实现
1. 树的存储
```
# 使用数组存储(完全二叉树)
tree = [0] * (4 * n) # 4n 空间足够
# 节点索引关系:
# - 父节点: i // 2
# - 左子节点: i * 2
# - 右子节点: i * 2 + 1
```
2. 建树
```
def build(node, start, end):
"""构建线段树"""
if start == end:
# 叶子节点:存储单个元素
tree[node] = nums[start]
return
mid = (start + end) // 2
# 递归构建左右子树
build(node * 2, start, mid) # 左子树
build(node * 2 + 1, mid + 1, end) # 右子树
# 父节点存储区间信息(如最大值)
tree[node] = max(tree[node * 2], tree[node * 2 + 1])
```
示例:
```
nums = [1, 3, 5, 7, 9, 11, 13, 15]
build(1, 0, 7) # 从根节点开始构建
# 结果 tree 数组:
# tree[1] = 15 # [0,7] 的最大值
# tree[2] = 7 # [0,3] 的最大值
# tree[3] = 15 # [4,7] 的最大值
# tree[4] = 3 # [0,1] 的最大值
# ...
```
3. 区间查询
```
def query(node, start, end, l, r):
"""查询区间 [l, r] 的最大值"""
# 情况 1:查询区间与当前区间无交集
if r < start or end < l:
return float('-inf') # 返回无效值
# 情况 2:查询区间完全包含当前区间
if l <= start and end <= r:
return tree[node] # 直接返回节点值
# 情况 3:部分重叠,递归查询
mid = (start + end) // 2
left_max = query(node * 2, start, mid, l, r)
right_max = query(node * 2 + 1, mid + 1, end, l, r)
return max(left_max, right_max)
```
示例:
```
# 查询 [2, 5] 的最大值
query(1, 0, 7, 2, 5)
# 查询路径:
# [0,7] → 部分重叠 → 查询 [0,3] 和 [4,7]
# [0,3] → 部分重叠 → 查询 [2,3]
# [4,7] → 部分重叠 → 查询 [4,5]
# [2,3] → 完全包含 → 返回 7
# [4,5] → 完全包含 → 返回 11
# 结果: max(7, 11) = 11
```
4. 区间更新
```
def update(node, start, end, idx, val):
"""更新单个元素 nums[idx] = val"""
if start == end:
# 找到叶子节点,更新值
tree[node] = val
return
mid = (start + end) // 2
if idx <= mid:
# 在左子树
update(node * 2, start, mid, idx, val)
else:
# 在右子树
update(node * 2 + 1, mid + 1, end, idx, val)
# 更新父节点
tree[node] = max(tree[node * 2], tree[node * 2 + 1])
```
### 完整示例
```python
class SegmentTree:
def __init__(self, nums):
self.n = len(nums)
self.nums = nums
self.tree = [0] * (4 * self.n)
self.build(1, 0, self.n - 1)
def build(self, node, start, end):
if start == end:
self.tree[node] = self.nums[start]
return
mid = (start + end) // 2
self.build(node * 2, start, mid)
self.build(node * 2 + 1, mid + 1, end)
self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1])
def query(self, l, r):
"""查询区间 [l, r] 的最大值"""
return self._query(1, 0, self.n - 1, l, r)
def _query(self, node, start, end, l, r):
if r < start or end < l:
return float('-inf')
if l <= start and end <= r:
return self.tree[node]
mid = (start + end) // 2
left_max = self._query(node * 2, start, mid, l, r)
right_max = self._query(node * 2 + 1, mid + 1, end, l, r)
return max(left_max, right_max)
def update(self, idx, val):
"""更新 nums[idx] = val"""
self._update(1, 0, self.n - 1, idx, val)
def _update(self, node, start, end, idx, val):
if start == end:
self.tree[node] = val
return
mid = (start + end) // 2
if idx <= mid:
self._update(node * 2, start, mid, idx, val)
else:
self._update(node * 2 + 1, mid + 1, end, idx, val)
self.tree[node] = max(self.tree[node * 2], self.tree[node * 2 + 1])
# 使用示例
nums = [1, 3, 5, 7, 9, 11, 13, 15]
st = SegmentTree(nums)
# 查询 [2, 5] 的最大值
print(st.query(2, 5)) # 输出: 11
# 更新 nums[3] = 20
st.update(3, 20)
# 再次查询 [2, 5] 的最大值
print(st.query(2, 5)) # 输出: 20
```
### 在最大二叉树中的应用
```python
class Solution:
def constructMaximumBinaryTree(self, nums):
if not nums:
return None
n = len(nums)
tree = [0] * (4 * n)
# 构建线段树(存储最大值的索引)
def build(node, start, end):
if start == end:
tree[node] = start
return
mid = (start + end) // 2
build(node * 2, start, mid)
build(node * 2 + 1, mid + 1, end)
# 存储较大值的索引
left_idx = tree[node * 2]
right_idx = tree[node * 2 + 1]
tree[node] = left_idx if nums[left_idx] > nums[right_idx] else right_idx
# 查询区间最大值的索引
def query(node, start, end, l, r):
if r < start or end < l:
return -1
if l <= start and end <= r:
return tree[node]
mid = (start + end) // 2
left_idx = query(node * 2, start, mid, l, r)
right_idx = query(node * 2 + 1, mid + 1, end, l, r)
if left_idx == -1:
return right_idx
if right_idx == -1:
return left_idx
return left_idx if nums[left_idx] > nums[right_idx] else right_idx
# 构建二叉树
def build_tree(start, end):
if start >= end:
return None
root_idx = query(1, 0, n - 1, start, end - 1)
root = TreeNode(nums[root_idx])
root.left = build_tree(start, root_idx)
root.right = build_tree(root_idx + 1, end)
return root
build(1, 0, n - 1)
return build_tree(0, n)
```
```
复杂度分析
操作 时间复杂度 空间复杂度
建树 O(n) O(n)
区间查询 O(log n) O(log n) 递归栈
单点更新 O(log n) O(log n) 递归栈
```
### 适用场景
适合:
- 频繁的区间查询
- 频繁的区间更新
- 数组大小固定
不适合:
- 数组频繁插入/删除
- 只需要单点查询(用数组即可)
- 数据量很小(暴力法更快)
总结
- 线段树 = 区间操作的加速器
- 建树:O(n) 预处理
- 查询:O(log n) 快速查询
- 更新:O(log n) 快速更新
核心思想:将区间分解为 O(log n) 个子区间,利用预处理的信息快速计算。