跳錶 - 簡明教程 in Python
# 1. 什麼是跳錶
跳錶(Skip List)是基於連結串列 + 隨機化實現的一個有序資料結構,可以達到平均 O(logN) 的查詢、插入、刪除效率,在實際執行中的效率往往超過 AVL 等平衡二叉樹,而且其實現相對更簡單、記憶體消耗更低。
Redis 的 ZSET 底層實現就是用的 Skip List,這裡是 [Antirez對此的說明](ofollow,noindex">https://news.ycombinator.com/item?id=1171423) 。
這是一個典型的跳錶:
[0] -> 0 -> 1 -> 3 -> 4 -> 5 -> 6 -> 7 -> 9 -> nil [1] -> 0 ------> 3 ------> 5 ------> 7 ------> nil [2]----------------------> 5-----------------> nil
解釋一下:
1. SkipList 是一個多層的連結串列
2. 第[0]層的連結串列包含所有節點,其他層的連結串列包含部分節點,層次越高,節點越少
3. 每層連結串列之間會共享相同的節點(節省記憶體,但為了方便展示,每一層都輸出了它的值)
4. 對於某個節點,在插入時通過概率判斷它最高會出現在哪一層,並且也會出現在之下的每一層
通過這樣的設計,當需要查詢某個 key 時,可以從最高層的連結串列開始往前找,在這一層遇到末尾或者大於 key 的節點時往下走一個層,直到找到 key 節點。
例如:
引用
4 的查詢路徑為 [2] -> [1] -> 0 -> 3 -> 3@[0] -> 4
6 的查詢路徑為 [2] -> 5 -> 5@[1] -> 5@[0] -> 6
8 的查詢路徑為 [2] -> 5 -> 5@[1] -> 7 -> 7@[0] -> 9 (找不到)
# 2. 跳錶的節點
從上面的描述,我們大概可以知道 (1) 每個節點需要儲存一個 key; (2) 每個節點需要有多個next指標 (3) 其 next 指標的數量會在插入時確定
因此我們可以用下面這個 class 來表示節點:
class Node(object) def __init__(self, height, key): self.key = key self.next = [None] * height def height(self): return self.height()
# 3. 建立跳錶
一個新建立的跳錶是沒有節點的。但為了實現的簡單起見,可以新增一個頭節點:
class SkipList(object): def __init__(self): self.head = Node(0, None) #頭節點高度為0,不需要key
到目前為止都特別簡單,但是還什麼也幹不了。
# 4. 建立節點
建立節點時,需要先按一定的概率分佈確定其高度。
為了保證高層的節點比低層少,我們可以用這樣的概率分佈:
引用
Height(n) = p^n
實現其實非常簡單:
import random def randomHeight(self, p = 0.5): height = 1 while random.uniform(0, 1) < p and self.head.height() >= height: height += 1 return height
這樣可以保證平均的路徑長度是 log(n) 。
精確一點的話,實際上是 log(n-1, 1/p) / p,也就是說, p 的選擇會影響跳錶層數、平均路徑長度。
具體的計算比較複雜,有興趣可以參考跳錶的原論文《Skip Lists: A Probabilistic Alternative to Balanced Trees》。(TL;DR)
然後我們就可以這樣來建立一個新的節點:
node = Node(self.randomHeight(), key)
# 5. 新增節點
如果只是為空跳錶新增一個新的節點,只要更新頭結點的每一個next指標:
def insertFirstNode(self, key): node = Node(self.randomHeight(), key) while node.height > self.head.height(): self.head.next.append(None) #保證頭節點的next陣列覆蓋所有層次的連結串列 for level in range(node.height()): node.next[level] = self.head.next[level] self.head.next[level] = node
但很顯然這個方法只能用一次。
如果跳錶中已經有多個節點,那我們就必須找到每一層中適合插入的位置:
def getUpdateList(self, key): update = [None] * self.head.height() for level in range(len(update)): x = self.head while x.next[level] is not None and x.next[level].key < key: x = x.next[level] update[level] = x return update
這個函式返回一個 update 節點陣列,其中的每個節點都是在這一層中小於 key 的最後一個節點。
也就是說,在 level = i 層,總是可以把新的節點插入 update[i] 之後:
def insert(self, key): node = Node(self.randomHeight(), key) while node.height > self.head.height(): self.head.next.append(None) #保證頭節點的next陣列覆蓋所有層次的連結串列 update = self.getUpdateList(key) next0 = update[0].next[0] if next0 is not None and next0.key == key: return # 0層總是包含所有元素;如果 update[0] 的下一個節點與key相等,則無需插入。 for level in range(node.height()): node.next[level] = update[level].next[level] update[level].next[level] = node
但是由於這一版 getUpdateList 是 O(n) 的,插入效率並沒有達到跳錶的設計目標。
# 6. 新增節點++
考慮這一點:跳錶的每一層都是有序的。
也就是說,我們在找到 update[n] = x 以後,其實可以從節點 x 的 n - 1 層繼續查詢來查詢 update[n-1] 。
由於查詢路徑的評價長度是 log(N) ,所以我們可以實現一個更快的 getUpdateList 方法
注意,需要從最高層開始查
def getUpdateList(self, key): update = [None] * self.head.height() x = self.head for level in reversed(range(len(update))): while x.next[level] is not None and x.next[level].key < key: x = x.next[level] update[level] = x return update
# 7. 里程碑1
把上面的程式碼整合起來,我們就可以得到第一版跳錶程式碼:能夠插入節點。
為了更好地展示我們的成果,我們可以用這樣一個函式,把連結串列按第1節的例子樣式輸出:
def dump(self): for i in range(self.head.height()): sys.stdout.write('[H]') x = self.head.next[0] y = self.head.next[i] while x is not None: s = ' -> %s' % x.key if x is y: y = y.next[i] else: s = '-' * len(s) x = x.next[0] sys.stdout.write(s) print ' -> <nil>' print
試試看:
sl = SkipList() for i in range(10): sl.insert(sl) s1.dump()
[H] -> 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9 -> <nil> [H]----- -> 1 -> 2 -> 3---------- -> 6 -> 7---------- -> <nil> [H]---------- -> 2-------------------- -> 7---------- -> <nil>
多嘗試幾次,以及選擇不同的 p 值,可以觀察生成跳錶的區別。
# 8. 查詢節點
實際上查詢節點的過程,已經包含在 insert 的實現裡了:
def find(self, key): update = self.getUpdateList(key) if len(update) == 0: return None next0 = update[0].next[0] if next0 is not None and next0.key == key: return next0 # 0層總是包含所有元素;如果 update[0] 的下一個節點與key相等,則無需插入。 else: return None
# 9. 刪除節點
既然已經能找出 update 節點陣列,在 level = i 層,只要判斷 update[i].next[i] 是否等於要刪除的 key 就可以了:
def remove(self, key): update = self.getUpdateList(key) for i, node in enumerate(update): if node.next[i] is not None and node.next[i].key == key: node.next[i] = node.next[i].next[i]
# 10. 里程碑2
整合 find 和 update 陣列,就可以實現跳錶的基礎操作了,試試看:
node = sl.find(3) print node for i in range(7, 14): sl.remove(i) sl.dump()
# 11. 其他
我們在 Node 中只添加了一個 key 屬性,在具體的實現中,我們往往可能需要針對 key 儲存一個 value,例如 Python 自帶的 dict 實現。改造起來也很簡單:
1. node 中新增一個 value 屬性,並且新增相應的初始化邏輯(__init__方法)
2. 將 SkipList.insert 修改為 `insert(self, key, value)`,在新建 Node 時指定其 value
3. 再新增一個 `update(self, key, value)` API,方便呼叫方的使用
4. 可以考慮針對語言適配,例如實現 python 的 __getitem__ 、 __setitem__ 等魔術方法
# 12. 完整程式碼
#coding:utf-8 import random class Node(object): def __init__(self, height, key=None): self.key = key self.next = [None] * height def height(self): return len(self.next) class SkipList(object): def __init__(self): self.head = Node(0, None) #頭節點高度為0,不需要key def randomHeight(self, p = 0.5): height = 1 while random.uniform(0, 1) < p and self.head.height() >= height: height += 1 return height def insert(self, key): node = Node(self.randomHeight(), key) print node.height(), node.key while node.height() > self.head.height(): self.head.next.append(None) #保證頭節點的next陣列覆蓋所有層次的連結串列 update = self.getUpdateList(key) if update[0].next[0] is not None and update[0].next[0].key == key: return # 0層總是包含所有元素;如果 update[0] 的下一個節點與key相等,則無需插入。 for level in range(node.height()): node.next[level] = update[level].next[level] update[level].next[level] = node def getUpdateList(self, key): update = [None] * self.head.height() x = self.head for level in reversed(range(len(update))): while x.next[level] is not None and x.next[level].key < key: x = x.next[level] update[level] = x return update def dump(self): for i in range(self.head.height()): sys.stdout.write('[H]') x = self.head.next[0] y = self.head.next[i] while x is not None: s = ' -> %s' % x.key if x is y: y = y.next[i] else: s = '-' * len(s) x = x.next[0] sys.stdout.write(s) print ' -> <nil>' print def find(self, key): update = self.getUpdateList(key) if len(update) == 0: return None next0 = update[0].next[0] if next0 is not None and next0.key == key: return next0 # 0層總是包含所有元素;如果 update[0] 的下一個節點與key相等,則無需插入。 else: return None def remove(self, key): update = self.getUpdateList(key) for i, node in enumerate(update): if node.next[i] is not None and node.next[i].key == key: node.next[i] = node.next[i].next[i]
完。
轉載請註明出自,如是轉載文則註明原出處,謝謝:)