container/heap包使用指南

本文基于官方文档介绍 golang 标准库中提供的堆/优先队列的使用方法.

1. 概述

“container/heap” 包提供了实现堆操作的接口,用户只需要定义满足 “heap.Interface” 接口的类型,就可以通过包提供的函数,像操作大根堆或小根堆一样,对实例数组变量进行 PushPop 操作。
堆通常是一个可以被看做一棵树的数组对象,堆总是满足下列性质:

  • 堆中某个结点的值总是不大于或不小于其父结点的值;
  • 堆总是一棵完全二叉树。

堆的定义如下:n 个元素的序列 {k1,k2,ki,…,kn} 当且仅当满足下关系时,称之为堆:

xxrlF0.png

下面将基于具体实例介绍"container/heap" 包的使用。

2. 整数堆

下面的代码实现了一个整数类型的最小堆:

package main import ( "container/heap" "log" ) type IntHeap []int // 下面几个函数必须实现,heap包会进行回调 func (h IntHeap) Len() int { return len(h) } // Less函数的实现决定最终实现是最大堆还是最小堆 func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] } func (h IntHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *IntHeap) Push(x any) { *h = append((*h), x.(int)) } func (h *IntHeap) Pop() any { old := *h n := len(old) x := old[n-1] *h = old[0 : n-1] return x } func main() { h := &IntHeap{2, 1, 5} heap.Init(h) heap.Push(h, 3) log.Println(*h) for h.Len() > 0 { log.Println(heap.Pop(h)) } }

3. 标准库堆的实现

通过整数堆的使用方式,发现是通过定义新的整数数组类型,并为其实现 pointer receivers 的方法 Push、Pop,以及 value receivers 的方法 Len、Less、Swap 方法之后,借助 “container/heap” 包提供的方法对该类型定义的几个方法进行回调,从而实现堆的功能。下面是具体的实现代码,逻辑见注释:

package heap import "sort" // Interface 接口指明了想要使用这个包中的方法去实现堆,应该提供的接口方法 // 在 heap.Init 方法被调用、数据为空或者原始数据有序时,满足下列条件的情况下,小根堆会被建立 // !h.Less(j, i) for 0 <= i < h.Len() and 2*i+1 <= j <= 2*i+2 and j < h.Len() // 注意:包中的方法会在添加和删除元素的时候调用Interface 接口实现的 Push 和 Pop方法,详见下面代码 type Interface interface { sort.Interface // 包括 Len()、Less(i, j int)、Swap(i, j int) Push(x any) // add x as element Len() Pop() any // remove and return element Len() - 1. } // 将传入的变量初始化为堆,时间复杂度为 O(n), 其中 n = h.Len(). func Init(h Interface) { // heapify n := h.Len() for i := n/2 - 1; i >= 0; i-- { down(h, i, n) } } // 向堆中加入新的元素,时间复杂度为 O(log n), 其中 n = h.Len(). func Push(h Interface, x any) { h.Push(x) up(h, h.Len()-1) } // Pop 移除并返回堆中的最小或最大元素,具体根据 h.Less 函数确定,时间复杂度为 O(log n). // Pop 等价于 Remove(h, 0). func Pop(h Interface) any { n := h.Len() - 1 h.Swap(0, n) down(h, 0, n) // 由此看出实现 h.Pop 方法时只需要将数组最后元素取出并返回即可 return h.Pop() } // Remove 移除并返回堆中索引为 i 的元素,时间复杂度为 O(log n). func Remove(h Interface, i int) any { n := h.Len() - 1 if n != i { h.Swap(i, n) if !down(h, i, n) { up(h, i) } } return h.Pop() } // 当堆数组中索引 i 处的元素的值或优先级发生变更的时候调用 Fix 调整元素 i 在堆中的位置 func Fix(h Interface, i int) { if !down(h, i, h.Len()) { up(h, i) } } // 向上进行堆调整,将新增元素上升到满足条件的位置 func up(h Interface, j int) { for { i := (j - 1) / 2 // parent if i == j || !h.Less(j, i) { break } h.Swap(i, j) j = i } } // 向下进行堆调整,确保i0节点是左右子树中的最小节点 func down(h Interface, i0, n int) bool { i := i0 for { j1 := 2*i + 1 if j1 >= n || j1 < 0 { // j1 < 0 after int overflow break } j := j1 // left child if j2 := j1 + 1; j2 < n && h.Less(j2, j1) { j = j2 // = 2*i + 2 // right child } if !h.Less(j, i) { break } h.Swap(i, j) i = j } return i > i0 }

4. 优先队列

举例实现针对结构实现的最小堆,也即一般意义上的优先队列:

package main import ( "container/heap" "log" ) type Node struct { Val int Next float32 } type NodeHeap []Node func (pq NodeHeap) Less(i, j int) bool { return pq[i].Val < pq[j].Val } func (pq NodeHeap) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } func (pq NodeHeap) Len() int { return len(pq) } func (pq *NodeHeap) Push(x any) { *pq = append(*pq, x.(Node)) } func (pq *NodeHeap) Pop() any { old := *pq n := len(old) x := old[n-1] *pq = old[0 : n-1] return x } func main() { pq := &NodeHeap{} heap.Init(pq) heap.Push(pq, Node{Val: 10, Next: 1.0}) heap.Push(pq, Node{Val: 11, Next: 2.0}) heap.Push(pq, Node{Val: 1, Next: 3.0}) for pq.Len() > 0 { log.Println(heap.Pop(pq).(Node)) } }

5. 优先队列的应用

力扣第 23 题「 合并 K 个升序链表」,要求合并 k 个有序链表为 1 个有序列表,如何快速得到 k 个节点中的最小节点,接到结果链表上?
此时就可以使用上述实现的优先队列了,不过需要稍微改动一下结构体,实现如下:

/** * Definition for singly-linked list. * type ListNode struct { * Val int * Next *ListNode * } */ type ItemHeap []*ListNode func (pq ItemHeap) Less(i, j int) bool { return pq[i].Val < pq[j].Val} func (pq ItemHeap) Swap(i, j int) {pq[i], pq[j] = pq[j], pq[i]} func (pq ItemHeap) Len() int {return len(pq)} func (pq *ItemHeap) Push(x interface{}) { *pq = append(*pq, x.(*ListNode)) } func (pq *ItemHeap) Pop() interface{} { old := *pq n := len(old) x := old[n-1] *pq = old[0:n-1] return x } func mergeKLists(lists []*ListNode) *ListNode { pq := &ItemHeap{} heap.Init(pq) for _, v := range lists { if v != nil { heap.Push(pq, v) } } // dummy dummy := &ListNode{} p := dummy for pq.Len() > 0 { x := heap.Pop(pq).(*ListNode) p.Next = x p = p.Next if x.Next != nil { heap.Push(pq, x.Next) } } return dummy.Next }

6. 参考资料