二叉查找树(Binary Search Tree),又称为二叉搜索树、二叉排序树。

链表插入数据很快,查询慢,数组查询快,插入慢,而二叉查找树则两者都比较快。

无特征的树结构基本没什么用。而叉查找树是一种有树结构有特征的结构,能够做到插入和查询的相对快速。

这篇文章是关于python二叉查找树的实现,这里会涉及以下几个部分

Node class

Insert method 插入方法
Lookup method 查找方法
Delete method 删除方法
Print method 打印方法
Comparing 2 trees 比较2个树
Generator returning the tree elemeents one by one 一个一个的返回元素

这是关于二分查找树(BST)的定义:

左子树的节点元素小于右子树的节点元素

右子树的节点元素大于右子树节点

左子树右子树也是二分查找树

这是一个BST的树结构


bst.png

Node class

要代表一个树结构,我们需要创建一个类包含三个属性

Left node
Right node
Node's data
class Node:
def __init__(self, data):
self.left = None
self.right = None
self.data = data

创建root节点:

root = Node(8)
Insert method
class Node:
......
def insert(self, data):
if self.data is None:
self.data = data
return
if data < self.data:
if self.left is None:
self.left = Node(data)
else:
self.left.insert(data) # 递归调用
elif data > self.data: # data 大于当前节点 应该放在右边
if self.right is None:
self.right = Node(data)
else:
self.right.insert(data)

insret()会递归的调用,用来正确的添加新的node在正确的树结构里面。

现在添加三个节点到跟节点,看下代码是什么

root.insert(3)

root.insert(10)

root.insert(1)

这是当添加了第二个节点(第一个是root节点,8) node(3)发生的事情:

root节点调用了insert(),插入一个data=3的数据

3小于8,左边的树为None,所以把3附加在了root节点的左边。

这是当添加了第三个节点node(10)发生的事情:

root节点的方法调用了insert方法,传递参数data=10

10大于8 并且右子树为None,所以把8添加到root节点的右节点

这是当添加了第四个节点node(1)发生的事情:

root 的节点的insert方法调用,传递参数data=1

1小于3,所以会root‘s 左child(3)的insert方法会调用,并且传递参数data=1

1 < 3,并且node(3)的左child为None, 所以node(1)添加在node(3)的左边。

现在这个二分查找树的结构像这样:


bst_insert.png

现在继续添加几个Node,开始下一部分关于Node节点的查找。

root.insert(6)

root.insert(4)

root.insert(7)

root.insert(14)

root.insert(13)


bst.png

Lookup method

我们需要一种方法在二叉查找树上查找指定的一个数,我们添加一个新的方法叫lookup,需要一个node节点的数据作为参数,返回None如果没有找到,返回节点已经父节点如果找到该数据。

class Node:
......
def lookup(self, data, parent=None):
if data < self.data:
if self.left is None:
return None, None
return self.left.lookup(data, self)
elif data > self.data:
if self.right is None:
return None, None
return self.right.lookup(data, self)
else:
return self, parent

现在查找node(6)试一试

node, parent = root.lookup(6)

以下流程解释了当lookup()函数调用之后的结果:

root's 的lookup() 函数调用了,data参数是6, parent是默认值为None

data=6小于root’s节点 8

root's 的左边的child的lookup()调用,传入data=6,parent=当前节点(node(3))

data=6 大于 data=3

node(3)的右边的child的lookup()函数调用,传入data=6,parent=当前节点

node's 的data数据相等,返回当前节点和父节点


bst_lookup.png

Delete method

delete()方法需要一个被删除节点的参数

class Node:
......
def delete(self, data):
node, parent = self.lookup(data)
if node is not None:
children_count = node.children_count()

这里有三种可能需要处理

要删除的节点没有child

要删除的节点有一个child

要删除的节点有两个child

处理第一种情况挺容易的,我们找到要删除的数据的节点,设置它的左child或者右child为None,如果是root节点,清除它的数据。

def delete(self, data):
node, parent = self.lookup(data)
if node is not None:
children_count = node.children_count()
if children_count == 0:
if parent:
if parent.left is node: # 判断是否是parent的左节点
parent.left = None
else:
parent.right = None
del node
else:
self.data = None # root node
Note: children_count()返回节点的child数量

这是函数children_count()的实现:

class Node():
......
def children_count(self):
cnt = 0
if self.left:
cnt += 1
if self.right:
cnt += 1
return cnt

举一个例子,我们想删除node(1), Node(3)的左child会设置为None

root.delete(1)

bst_delete_0.png

现在,让我们处理第二种情况,删除的节点有一个child,我们将节点替换为它的child,当是root节点的时候,还做了特殊处理。

elif children_count == 1:
if node.left:
n = node.left
else:
n = node.right
if parent:
if parent.left is node: # 判断是否是parent的左节点
parent.left = n
else:
parent.right = n
del node
else: # root node
self.left = n.left
self.right = n.right
self.data = n.data

例如,我们想删除node(14), Node 14 的data会设置为13(它的左child's data) ,并且它的左child会设置为None

root.delete(14)


bst_delete_1.png

让我们看最后一种可能,删除的节点有2个children,替换数据为继承者的数据,然后修复继承者的parent‘s的child

else:
# if node has 2 children
# 找到它的继承者
parent = node
successor = node.right
while successor.left:
parent = successor
successor = successor.left
# 替换接点数据为继承者(子节点)数据
node.data = successor.data
if parent.left == successor:
parent.left = successor.right
else:
parent.right = successor.right

例如, 我们想删除node 3, 我们正确的查找到了叶子节点(最后一个节点)才离开循环,把node 3 替换为了node 4,node 4没有child,所以把node 6作为右child。

root.delete(3)


bst_delete_2.png


mytest.png

Print method

我们添加了一个方法打印树的中序(左中右),这个方法不需要参数。我们在print_tree()里面使用递归来走到树的最深处,先贯穿整个左子树,再打印root节点,然后贯穿右子树。

class Node:
......
def print_tree(self):
if self.left:
self.left.print_tree()
print(self.data, end=' ')
if self.right:
self.right.print_tree()

输出的结果是: 1, 3, 4, 6, 7, 8, 10, 13, 14

Comparing 2 trees

为了比较2个树,我们添加一个方法用来递归的比较每一个子树,当有叶子节点不在相同的两个树中,返回False,包括叶子节点丢失和data不一样。需要传入一个root节点的树来作为比较的参数。

def compare_trees(self, node):
if node is None: # 传入的节点为None
return False
if self.data != node.data: # 当前节点的数据和其它树的节点的数据不同
return False
res = True
if self.left is None:
if node.left: # 当前节点没有右节点 但是其它树节点有
return False
else:
res = self.left.compare_trees(node.left) # 比较下一个左节点的数据
if res is False:
return False
if self.right is None:
if node.right: # 逻辑同上 现在遍历右边的节点
return False
else:
res = self.right.compare_trees(node.right)
return res

例如,我们想要比较树(3, 8, 10) 和树(3, 8, 11)


bst_compare.png

# root2 is the root of tree 2
root.compare_trees(root2)

当调用compare_trees()会发生以下的情况:

root 节点的compare_trees() 方法调用了,用来比较另一个树

root 节点有左child,所以调用了左child的compare_trees()方法

左子树比较完返回True

root 节点有右child,所以调用了右child的compare_trees()方法

右子树不同,返回False

整个compare_trees()返回False

Generator returning the tree elements one by one

有时候使用生成器繁护i数据是非常有用的,它节省内存不需要立刻产生一个完整的list。每次调用,都会返回下一个node的值。

为了做到这样的效果,使用yield的关键字,他会返回一个对象,并且停止。所以这个函数会在下一次调用继续执行,我们不能使用递归,使用栈(FILO)代替。

def tree_data(self):
stack = []
node = self
while stack or node:
if node: # 判断当前node是否为None
stack.append(node)
node = node.left
else: # 当前node为None, 那么yield返回数据
node = stack.pop()
yield node.data
node = node.right

让我们来看一下这些操作做了什么事情:


bst.png

1- The root node tree_data() method is called.
2- Node 8 is added to the stack. We go to the left child of 8.
3- Node 3 is added to the stack. We go to the left child of 3.
4- Node 1 is added to the stack. Node is set to None because there is no left child.
5- We pop a node which is Node 1. We yield it (returns 1 and stops here until tree_data() is called again.
6- tree_data() is called again because we are using it in a for loop.
7- Node is set to None because Node 1 doesn’t have a right child.
8- We pop a node which is Node 3. We yield it (returns 3 and stops here until tree_data() is called again.
…

Here you go, I hope you enjoyed this tutorial. Don’t hesitate to add comments if you have any feedback.

完整代码

class Node:
def __init__(self, data):
self.left = None # Node
self.right = None
self.data = data
def insert(self, data):
if self.data is None:
self.data = data
return
if data < self.data:
if self.left is None:
self.left = Node(data)
else:
self.left.insert(data) # 递归调用
elif data > self.data: # data 大于当前节点 应该放在右边
if self.right is None:
self.right = Node(data)
else:
self.right.insert(data)
def lookup(self, data, parent=None):
if data < self.data:
if self.left is None:
return None, None
return self.left.lookup(data, self)
elif data > self.data:
if self.right is None:
return None, None
return self.right.lookup(data, self)
else:
return self, parent
def delete(self, data):
node, parent = self.lookup(data)
if node is not None:
children_count = node.children_count()
else:
return
if children_count == 0:
if parent:
if parent.left is node: # 判断是否是parent的左节点
parent.left = None
else:
parent.right = None
del node
else:
self.data = None # root node
elif children_count == 1:
if node.left:
n = node.left
else:
n = node.right
if parent:
if parent.left is node: # 判断是否是parent的左节点
parent.left = n
else:
parent.right = n
del node
else: # root node
self.left = n.left
self.right = n.right
self.data = n.data
else:
# if node has 2 children
# 找到它的继承者
parent = node
successor = node.right # 因为左子树小于右子树
print(successor.data)
while successor.left:
parent = successor
successor = successor.left
# 替换接点数据为继承者(子节点)数据
node.data = successor.data
if parent.left == successor:
# print(parent.left.data, '---', successor.data)
parent.left = successor.right
else:
parent.right = successor.right
def children_count(self):
cnt = 0
if self.left:
cnt += 1
if self.right:
cnt += 1
return cnt
def print_tree(self):
if self.left:
self.left.print_tree()
print(self.data, end=' ')
if self.right:
self.right.print_tree()
def compare_trees(self, node):
if node is None: # 传入的节点为None
return False
if self.data != node.data: # 当前节点的数据和其它树的节点的数据不同
return False
res = True
if self.left is None:
if node.left: # 当前节点没有右节点 但是其它树节点有
return False
else:
res = self.left.compare_trees(node.left) # 比较下一个左节点的数据
if res is False:
return False
if self.right is None:
if node.right: # 逻辑同上 现在遍历右边的节点
return False
else:
res = self.right.compare_trees(node.right)
return res
def tree_data(self):
stack = []
node = self
while stack or node:
if node: # 判断当前node是否为None
stack.append(node)
node = node.left
else: # 当前node为None, 那么yield返回数据
node = stack.pop()
yield node.data
node = node.right
def main():
root = Node(8)
root.insert(3)
root.insert(10)
root.insert(1)
root.insert(6)
root.insert(4)
root.insert(7)
root.insert(14)
root.insert(13)
node, parent = root.lookup(6)
print(node.data, parent.data)
# root.delete(3) 1, 3, 4, 6, 7, 8, 10, 13, 14
print('*'*10)
root.print_tree()
print()
# 比较两个树
comp_root_1 = Node(8)
comp_root_1.insert(3)
comp_root_1.insert(10)
comp_root_2 = Node(8)
comp_root_2.insert(3)
comp_root_2.insert(11)
print(comp_root_1.compare_trees(comp_root_2))
# 打印元素
for i in root.tree_data():
print(i, end=' ')
print()
if __name__ == '__main__':
main()