题目链接:432. 全 O(1) 的数据结构
使用一个循环的双向链表,从根结点遍历,顺序从最小值开始,最大值结束。
- 递增操作是将该 key 从
cnt对应的结点,挪到cnt + 1对应的结点,若没有目标结点,则创建,若源结点变为空,则删除该结点,该操作的时间复杂度是 $O(1)$; - 递减操作是将该 key 从
cnt对应的结点,挪到cnt - 1对应的结点,若没有目标结点,则创建,若源结点变为空,则删除该结点,该操作的时间复杂度是 $O(1)$; - 根结点后的第一个元素就是最小值,该操作的时间复杂度是 $O(1)$;
- 根结点前的第一个元素就是最大值,该操作的时间复杂度是 $O(1)$。
需要处理很多边界条件,比如 cnt 减小为 $0$,每个结点对应多个 key,其列表被删除为空等。
class Node:
def __init__(self, key: str, cnt: int):
self.keys, self.cnt, self.prev, self.next = {key}, cnt, self, self
def insert_next(self, node: 'Node'):
self.next.prev, self.next, node.next, node.prev = node, node, self.next, self,
def insert_prev(self, node: 'Node'):
self.prev.next, self.prev, node.prev, node.next, = node, node, self.prev, self,
def remove(self):
self.prev.next, self.next.prev = self.next, self.prev
class AllOne:
def __init__(self):
self.root, self.nodes = Node('', 0), dict()
def inc(self, key: str) -> None:
if key in self.nodes:
node, cnt = self.nodes[key], self.nodes[key].cnt + 1
node.keys.remove(key)
else:
node, cnt = self.root, 1
if node.next.cnt == cnt:
node.next.keys.add(key)
self.nodes[key] = node.next
else:
new = Node(key, cnt)
node.insert_next(new)
self.nodes[key] = new
if not node.keys:
node.remove()
def dec(self, key: str) -> None:
if key in self.nodes:
node, cnt = self.nodes[key], self.nodes[key].cnt - 1
node.keys.remove(key)
if 0 < cnt == node.prev.cnt:
node.prev.keys.add(key)
self.nodes[key] = node.prev
elif cnt > 0:
new = Node(key, cnt)
node.insert_prev(new)
self.nodes[key] = new
else:
del self.nodes[key]
if not node.keys:
node.remove()
def getMaxKey(self) -> str:
return next(iter(self.root.prev.keys))
def getMinKey(self) -> str:
return next(iter(self.root.next.keys))