有序字典

我们之前讲过字典,并且尝试过用非排序列表和哈希表实现字典。
那么什么是有序字典呢?
就是字典中的key按照顺序排布
和collections中的有序字典还不一样,那个是保存了用户添加key-value对的顺序。

实现

这里我们采用一个列表来存储。
因为是有序列表,所以我们用二分法来进行查找,进而实现插入和删除、修改。

代码:

from collections import MutableMapping
class MapBase(MutableMapping):
    """lightweight composite to store key-value pairs as map items"""
    class _Item():
        __slots__ = '_key', '_value'
        def __init__(self,k,v):
            self._key = k
            self._value = v
        def __eq__(self,other):
            return self._key == other._key
        def __ne__(self,other):
            return not(self==other)
        def __lt__(self,other):
            return self._key < other._key

class SortedTableMap(MapBase):
    """Map implementation using a sorted table"""
    def _find_index(self,k,low,high):
        """return index of the leftmost item with key greater than or equal to k
        return high+1 if not find"""
        if high < low:
            return high+1
        else:
            mid = (low+high)//2
            if k == self._table[mid]._key:
                return mid
            elif k< self._table[mid]._key:
                return self._find_index(k,low,mid-1)
            else:
                return self._find_index(k,mid+1,high)
    def __init__(self):
        """create an empty map"""
        self._table = []
    def __len__(self):
        """return number of items in the map"""
        return len(self._table)
    def __getitem__(self,k):
        """return value associated with key k"""
        j = self._find_index(k,0,len(self._table)-1)
        if j == len(self._table) or self._table[j]._key != k:
            raise KeyError('Key Error: '+repr(k))
        return self._table[j]._value
    def __setitem__(self,k,v):
        """assign value v to key k, overwriting existing value if present"""
        j = self._find_index(k,0,len(self._table)-1)
        if j<len(self._table) and self._table[j]._key ==k:
            self._table[j]._value = v
        else:
            self._table.insert(j,self._Item(k,v))
    def __delitem__(self,k):
        """remove item associated with key k"""
        j = self._find_index(k,0,len(self._table)-1)
        if j==len(self._table) or self._table[j]._key != k:
            raise KeyError('key error: '+repr(k))
        self._table.pop(j)
    def __iter__(self):
        """generate keys of the map ordered from minimum to maximum"""
        for item in self._table:
            yield item._key
    def __reversed__(self):
        """generate keys of the map ordered from maximum to minimum"""
        for item in reversed(self._table):
            yield item._key
    def find_min(self):
        """return (key, value) pair with minimum key"""
        if len(self._table):
            return (self._table[0]._key,self._table[0]._value)
        else:
            return None
    def find_max(self):
        """return (key, value) pair with maximum key"""
        if len(self._table):
            return (self._table[-1]._key,self._table[-1]._value)
        else:
            return None
    def find_ge(self,k):
        """return (key, value) pair with least key greater than(or equal to) key"""
        j = self._find_index(k,0,len(self._table)-1)
        if j<len(self._table):
            return (self._table[j]._key,self._table[j]._value)
        else:
            return None
    def find_le(self,k):
        """return (key,value) pair with greatest key less than(or equal to) key"""
        j = self._find_index(k,0,len(self._table)-1)
        if self._table[j]._key == k:
            return (self._table[j].key,self._table[j]._value)
        elif j:
            return (self._table[j-1]._key,self._table[j-1]._value)
        else:
            return None
    def find_lt(self,k):
        """return (key,value) pair with greatest key strictly less than k"""
        j = self._find_index(k,0,len(self._table)-1)
        if j>0:
            return (self._table[j-1]._key,self._table[j-1]._value)
        else:
            return None
    def find_gt(self,k):
        """return (key,value) pair with least key strictly greater than k"""
        j = self._find_index(k,0,len(self._table)-1)
        if j<len(self._table) and self._table[j]._key == k:
            j += 1
        if j<len(self._table):
            return (self._table[j]._key,self._table[j]._value)
        else:
            return None
    def find_range(self,start,stop):
        """iterate all (key,value) pair such that start<=key<stop
        if start is none,iteration begins with minimum key of map
        if stop is none,iteration continues through the maximum key of map"""
        if start is None:
            j = 0
        else:
            j = self._find_index(start,0,len(self._table)-1)
        # j在范围内并且符合要求
        while j<len(self._table) and (stop is None or self._table[j]._key<stop):
            yield (self._table[j]._key, self._table[j]._value)
            j += 1

看着很多,但是我们来划分一下。
第一个提供了一个映射的基类,来存储一个key-value对,同时重构了一些常用方法(判断相等、比较),注意一下相等和不等的条件差了一点。

第二部分虽然看着是子类,其实是对第一部分的一个补充。
第一个非公有方法_find_index是寻找一个合适的位置。
为什么是说合适的,因为这个查找不但在引用、修改被使用,同时在插入也会使用(提供一个合适的插入位置),所以才这样说。

他那个加一就很狗,但确实是有效,二分法没得说,这里我来分析一下为什么加一能实现找到一个合适的插入位置呢?
首先,我们判断一下什么时候才能有high<low,初始不行(除非你带入错了),剩下的就是在mid位置的键小于我们的k时,传入low和mid-1,当mid[=(high+low)//2] -1 < low,只可能是low=high(-1)两种情况,具体来看一下。

low=high-1:

此时有mid = low,如果要出现low < high,要有mid的k值大于输入的k值,但mid=low,也就是区间范围内的元素都大于输入的键值,说明是没找到,当我们传入(low,mid-1),返回的是mid,也就是我们的low,刚好是我们要插入的位置。(选出这个区间,说明前面的都小于,区间都大于,只能插low了。

low=high

这个更好判断,mid=low,要是还大于,只能插入low,也就是(mid-1)+1。

其实这些直接判断也可以,但是个人不太行,二分写得少,所以来分析一下。

接下来的,是三个对元素的处理,只需要对着_find_index来操作即可。
对于返回的位置,我们判断一下是否存储了需要的键值,如果是就该咋咋地,如果不是,该插入插入,该报错报错。

生成器和最大最小值,自己看看得了。

然后是四个找大于、小于、大于等于、小于等于的,如果没有返回空。(比如我传入键是5,要找一个键小于等于5中最大的)。这个……,看着就很蠢。
大于等于和小于等于先找一下,找到了直接返回(是等于的情况),剩下的情况和大于、小于一样,就记住一点:
要么返回k的位置,要么返回k该插入的位置

最后的find_range是一个左闭右开区间,得到的方式看注释即可。