算法_大小根堆的使用

Python 中的堆(heapq)使用详解

Python 中的堆是通过 heapq 模块实现的,默认是小顶堆(最小堆)。


一、堆的基本操作

1. 导入和创建堆

1
2
3
4
5
6
7
8
9
import heapq

# Create an empty heap (a plain list; the heapq functions maintain the heap invariant)
heap = []

# Or turn an existing list into a heap
nums = [3, 1, 4, 1, 5, 9, 2]
heapq.heapify(nums) # in place, O(n)
print(nums) # [1, 1, 2, 3, 5, 9, 4] (min-heap layout, not fully sorted)

2. 基本操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import heapq

heap = []

# Add elements (each heappush is O(log n))
heapq.heappush(heap, 5)
heapq.heappush(heap, 2)
heapq.heappush(heap, 3)
heapq.heappush(heap, 1)
print(heap) # [1, 2, 3, 5] (heap order; not guaranteed to be sorted)

# Pop the smallest element
smallest = heapq.heappop(heap) # returns 1
print(heap) # [2, 5, 3] (heap order, visibly not sorted)

# Peek at the smallest element without popping: it is always heap[0]
print(heap[0]) # 2

# Pop the smallest, then push a new item, in a single call
# (more efficient than calling heappop and heappush separately)
result = heapq.heapreplace(heap, 4) # pops 2, pushes 4
print(result) # 2
print(heap) # [3, 5, 4]

# Push first, then pop the smallest, in a single call
result = heapq.heappushpop(heap, 0) # pushes 0 and pops 0 right back (0 is the smallest)
print(result) # 0
print(heap) # [3, 5, 4] (unchanged)

3. 复杂元素(元组、对象)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import heapq

# Priority queue built from (priority, item) tuples: tuples compare element by
# element, so the entry with the smallest priority is popped first
heap = []
heapq.heappush(heap, (2, "任务B"))
heapq.heappush(heap, (1, "任务A"))
heapq.heappush(heap, (3, "任务C"))

priority, task = heapq.heappop(heap)
print(f"优先级{priority}: {task}") # prints: 优先级1: 任务A

# Storing custom objects: heapq orders its elements with `<`,
# so the class has to implement __lt__.
class Task:
    """A named job ordered by its numeric priority (lower priority pops first)."""

    def __init__(self, priority, name):
        self.priority, self.name = priority, name

    def __lt__(self, other):
        # Only priorities take part in the ordering; names never break ties.
        return self.priority < other.priority

    def __repr__(self):
        return f"{self.name}({self.priority})"

# Build the heap of Task objects in one step, then drain it:
# tasks come out in ascending priority order.
heap = [Task(3, "任务C"), Task(1, "任务A"), Task(2, "任务B")]
heapq.heapify(heap)

while heap:
    print(heapq.heappop(heap))  # 任务A(1), then 任务B(2), then 任务C(3)

二、常用函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import heapq

data = [3, 1, 4, 1, 5, 9, 2]

# The 3 largest elements, returned in descending order
largest3 = heapq.nlargest(3, data)
print(largest3) # [9, 5, 4]

# The 3 smallest elements, returned in ascending order
smallest3 = heapq.nsmallest(3, data)
print(smallest3) # [1, 1, 2]

# A key function can be supplied (same as sorted)
people = [
    {'name': 'Alice', 'age': 30},
    {'name': 'Bob', 'age': 25},
    {'name': 'Charlie', 'age': 35}
]
youngest = heapq.nsmallest(2, people, key=lambda x: x['age'])
print(youngest) # [{'name': 'Bob', 'age': 25}, {'name': 'Alice', 'age': 30}]

三、常见题型

题型1:Top K 问题

例1:数组中的第K个最大元素(LeetCode 215)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import heapq
from typing import List

class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        """Return the k-th largest value in nums.

        A min-heap holds the k largest values seen so far. Once it is full,
        each new value is pushed and the smallest is evicted in a single
        heappushpop call. The root of the heap is then the k-th largest.
        """
        window: List[int] = []
        for value in nums:
            if len(window) < k:
                heapq.heappush(window, value)
            else:
                # Push value and drop the current minimum in one O(log k) step.
                heapq.heappushpop(window, value)
        return window[0]

# A more concise version
class Solution:
    def findKthLargest(self, nums: List[int], k: int) -> int:
        """Return the k-th largest value using heapq.nlargest.

        nlargest yields the k largest values in descending order,
        so the k-th largest is the last one in that list.
        """
        top_k = heapq.nlargest(k, nums)
        return top_k[-1]

例2:前K个高频元素(LeetCode 347)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import heapq
from collections import Counter
from typing import List

class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        """Return the k most frequent values in nums (in heap order, not sorted).

        A min-heap of (frequency, value) pairs is capped at k entries, so the
        least frequent candidate is evicted whenever the heap grows past k.
        """
        min_heap = []
        for value, freq in Counter(nums).items():
            heapq.heappush(min_heap, (freq, value))
            if len(min_heap) > k:
                heapq.heappop(min_heap)  # evict the least frequent candidate
        return [entry[1] for entry in min_heap]

# More concise with nlargest
class Solution:
    def topKFrequent(self, nums: List[int], k: int) -> List[int]:
        """Return the k most frequent values, most frequent first.

        heapq.nlargest picks the k (value, count) pairs with the highest counts.
        """
        freq = Counter(nums)
        most_common_pairs = heapq.nlargest(k, freq.items(), key=lambda pair: pair[1])
        return [value for value, _count in most_common_pairs]

题型2:多路归并

例3:合并K个升序链表(LeetCode 23)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import heapq
from typing import List, Optional

# Definition for singly-linked list.
class ListNode:
    """A singly linked list node holding a value and a link to the next node."""

    def __init__(self, val=0, next=None):
        self.val, self.next = val, next


class Solution:
    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        """Merge k ascending linked lists into one ascending list.

        The heap holds at most one node per input list as (value, list index,
        node). The list index breaks ties between equal values, so ListNode
        objects themselves are never compared. Existing nodes are relinked;
        no new value nodes are created.
        """
        heap = [(head.val, idx, head) for idx, head in enumerate(lists) if head]
        heapq.heapify(heap)

        sentinel = tail = ListNode(0)  # dummy head simplifies appending
        while heap:
            _, idx, node = heapq.heappop(heap)
            tail.next = node
            tail = node
            successor = node.next
            if successor:
                # Replace the popped node with the next node from the same list.
                heapq.heappush(heap, (successor.val, idx, successor))
        return sentinel.next

题型3:数据流中的中位数

例4:数据流的中位数(LeetCode 295)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import heapq

class MedianFinder:
    """Running median of a stream of numbers, maintained with two heaps.

    ``small`` holds the lower half as a max-heap (values stored negated);
    ``large`` holds the upper half as a min-heap. After every ``addNum``,
    every value in the lower half is <= every value in the upper half, and
    the two heaps' sizes differ by at most one.
    """

    def __init__(self):
        # Python's heapq only provides a min-heap, so the max-heap stores negatives.
        self.small = []  # lower half, negated (max-heap)
        self.large = []  # upper half (min-heap)

    def addNum(self, num: int) -> None:
        """Insert num into the structure (a constant number of O(log n) heap ops)."""
        # Tentatively put num in the lower half.
        heapq.heappush(self.small, -num)

        # Keep every element of small <= every element of large: if the lower
        # half's max now exceeds the upper half's min, move it across.
        if self.small and self.large and (-self.small[0] > self.large[0]):
            heapq.heappush(self.large, -heapq.heappop(self.small))

        # Rebalance so the heap sizes differ by at most one.
        if len(self.small) > len(self.large) + 1:
            heapq.heappush(self.large, -heapq.heappop(self.small))
        elif len(self.large) > len(self.small) + 1:
            heapq.heappush(self.small, -heapq.heappop(self.large))

    def findMedian(self) -> float:
        """Return the median of all numbers added so far, always as a float.

        With an odd count the larger heap's root is the median; with an even
        count it is the mean of both roots.

        Raises:
            IndexError: if no number has been added yet.
        """
        if len(self.small) > len(self.large):
            # float() keeps the return type consistent with the annotation.
            return float(-self.small[0])
        if len(self.large) > len(self.small):
            return float(self.large[0])
        return (-self.small[0] + self.large[0]) / 2

题型4:任务调度

例5:任务调度器(LeetCode 621)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import heapq
from collections import Counter, deque
from typing import List

class Solution:
    def leastInterval(self, tasks: List[str], n: int) -> int:
        """Return the number of CPU intervals needed to run all tasks with cooldown n.

        Greedy tick-by-tick simulation: on each tick, run one copy of the task
        with the most remaining copies. Python has no max-heap, so counts are
        negated. A task that still has copies left waits in a FIFO queue and
        is pushed back onto the heap at the end of tick (run tick + n).
        Idle ticks happen when the heap is empty but tasks are still cooling down.
        """
        pending = [-remaining for remaining in Counter(tasks).values()]
        heapq.heapify(pending)

        cooling = deque()  # entries: (tick when it returns to the heap, negated remaining count)
        clock = 0

        while pending or cooling:
            clock += 1

            if pending:
                left = heapq.heappop(pending) + 1  # run one copy
                if left < 0:  # copies remain -> start cooldown
                    cooling.append((clock + n, left))

            if cooling and cooling[0][0] == clock:
                heapq.heappush(pending, cooling.popleft()[1])

        return clock

题型5:最小开销问题

例6:连接棍子的最小开销(LeetCode 1167)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import heapq
from typing import List

class Solution:
    def connectSticks(self, sticks: List[int]) -> int:
        """Return the total cost of joining all sticks into one.

        Joining sticks of lengths a and b costs a + b and produces a stick of
        length a + b. The two shortest sticks are always joined first (the
        standard Huffman-style greedy).

        The caller's list is left untouched: the heap is built on a copy.
        An empty list or a single stick costs 0.
        """
        heap = list(sticks)  # copy: heapify/heappop would otherwise consume the caller's list
        heapq.heapify(heap)
        total_cost = 0

        while len(heap) > 1:
            # Always join the two shortest sticks.
            shortest = heapq.heappop(heap)
            second = heapq.heappop(heap)
            joined = shortest + second
            total_cost += joined
            heapq.heappush(heap, joined)

        return total_cost

题型6:K个有序数组的最小覆盖区间

例7:Smallest Range Covering Elements from K Lists(LeetCode 632)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import heapq
from typing import List

class Solution:
    def smallestRange(self, nums: List[List[int]]) -> List[int]:
        """Return the smallest range [a, b] containing at least one number from each list.

        Each list in nums is expected to be sorted ascending and non-empty.
        Range [a, b] is smaller than [c, d] if b - a < d - c, or if a < c when
        the widths are equal.

        The heap keeps one (value, row, col) entry per list, and max_val tracks
        the largest of those values. The heap root and max_val bound a
        candidate range. Only advancing the list that owns the minimum can
        shrink the range, and the sweep stops as soon as that list is exhausted.

        Raises:
            ValueError: if nums is empty.
            IndexError: if any inner list is empty.
        """
        if not nums:
            raise ValueError("nums must contain at least one list")

        # Initialise with the first element of every list.
        heap = [(row[0], r, 0) for r, row in enumerate(nums)]
        heapq.heapify(heap)
        max_val = max(entry[0] for entry in heap)

        # Seed the answer from the first window instead of a hard-coded
        # sentinel, so any value range is handled correctly.
        best_start, best_end = heap[0][0], max_val

        while True:
            min_val, r, c = heapq.heappop(heap)

            # Strict "<" keeps the earliest window on ties. The popped minimum
            # never decreases, so that window also has the smallest start.
            if max_val - min_val < best_end - best_start:
                best_start, best_end = min_val, max_val

            if c + 1 == len(nums[r]):
                break  # list r is exhausted: no later window can include it

            next_val = nums[r][c + 1]
            heapq.heappush(heap, (next_val, r, c + 1))
            max_val = max(max_val, next_val)

        return [best_start, best_end]

四、堆 vs 其他数据结构

操作 堆 排序数组 平衡树
插入 O(log n) O(n) O(log n)
删除最小 O(log n) O(1) O(log n)
查看最小 O(1) O(1) O(log n)
查找任意元素 O(n) O(log n)(二分查找) O(log n)
空间 O(n) O(n) O(n)

五、使用技巧

  1. 大顶堆模拟:存负数
  2. 复杂排序:存元组 (优先级, 索引, 值)
  3. Top K问题:固定堆大小
  4. 动态数据流:双堆技巧
  5. 懒删除:用字典标记已删除元素

六、时间复杂度总结

  • heapify(): O(n)
  • heappush(): O(log n)
  • heappop(): O(log n)
  • heapreplace(): O(log n)
  • heappushpop(): O(log n)
  • nlargest(k) / nsmallest(k): O(n log k)