import math
import pickle
from matplotlib import pyplot as plt


def calc_shang(dataset: list):
    """
    Compute the Shannon entropy of the given dataset.
    :param dataset: list of samples; the last element of each sample is its class label
    :return: the Shannon entropy of the label distribution, in bits
    """
    length = len(dataset)
    label_count_map = {}
    for item in dataset:
        current_label = item[-1]
        if current_label not in label_count_map:
            label_count_map[current_label] = 0
        label_count_map[current_label] += 1
    shang = 0.0
    for label, count in label_count_map.items():
        prob = count / length
        shang -= prob * math.log(prob, 2)
    return shang
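

# A quick sanity check on calc_shang: a 50/50 split between two labels is
# exactly 1 bit of entropy, and a single-label set is 0 bits.
# >>> calc_shang([[1, "yes"], [1, "no"]])
# 1.0
# >>> calc_shang([[1, "yes"], [1, "yes"]])
# 0.0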


def create_dataset():
    """Return a toy dataset (last column is the class label) and its feature names."""
    dataset = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"]
    ]
    labels = ["no surfacing", "flippers"]
    return dataset, labels


def split_dataset(dataset, axis, value):
    """Return the samples whose feature at `axis` equals `value`, with that column removed."""
    new_dataset = []
    for item in dataset:
        if item[axis] == value:
            reduced_item = item[:axis]
            reduced_item.extend(item[axis + 1:])
            new_dataset.append(reduced_item)
    return new_dataset
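

# Example with rows from the toy dataset above: splitting on feature 0 with
# value 1 keeps the two matching rows and drops that column.
# >>> split_dataset([[1, 1, "yes"], [1, 0, "no"], [0, 1, "no"]], 0, 1)
# [[1, 'yes'], [0, 'no']]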


def choose_best_feature(dataset):
    """Pick the feature with the highest information gain (the ID3 criterion)."""
    num = len(dataset[0]) - 1
    shang = calc_shang(dataset)
    best_info_gain = 0
    best_feature = -1
    for i in range(num):
        feat_list = [item[i] for item in dataset]
        unique_list = set(feat_list)
        # Conditional entropy of the labels after splitting on feature i
        _shang = 0
        for feat in unique_list:
            sub_dataset = split_dataset(dataset, i, feat)
            prob = len(sub_dataset) / len(dataset)
            _shang += prob * calc_shang(sub_dataset)
        info_gain = shang - _shang
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
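

# Worked numbers for the toy dataset: the base entropy of the labels
# (2 "yes", 3 "no") is about 0.971 bits. Splitting on "no surfacing" leaves
# about 0.551 bits (gain ~0.420), while "flippers" leaves 0.8 bits
# (gain ~0.171), so feature 0 is chosen at the root.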


def majority_count(class_list):
    """Return the majority class label; used when no features are left to split on."""
    class_count_map = {}
    for item in class_list:
        if item not in class_count_map:
            class_count_map[item] = 0
        class_count_map[item] += 1
    sorted_class_count_map = sorted(class_count_map.items(), key=lambda x: x[1], reverse=True)
    return sorted_class_count_map[0][0]


def create_tree(dataset, labels):
    """Recursively build an ID3 decision tree as nested dicts. Note: mutates `labels`."""
    class_list = [item[-1] for item in dataset]
    # Base case 1: all samples share one label
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Base case 2: no features left to split on -> majority vote
    if len(dataset[0]) == 1:
        return majority_count(class_list)
    best_feature = choose_best_feature(dataset)
    best_class_label = labels[best_feature]
    tree = {best_class_label: {}}
    del labels[best_feature]
    feat_values = [item[best_feature] for item in dataset]
    unique_values = set(feat_values)
    for value in unique_values:
        sub_labels = labels[:]
        tree[best_class_label][value] = create_tree(split_dataset(dataset, best_feature, value), sub_labels)
    return tree
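

# For the toy dataset, create_tree produces a nested dict of the form
# {feature_name: {feature_value: subtree_or_label}}:
# {"no surfacing": {0: "no", 1: {"flippers": {0: "no", 1: "yes"}}}}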


def plot_tree(tree, root_name):
    def _plot_tree(ax, tree, parent_name, parent_x, parent_y, dx, dy):
        # Draw the root node on the very first call only (the recursion never
        # revisits the origin, since child nodes always move downward)
        if parent_name and parent_x == 0 and parent_y == 0:
            ax.text(0, 0, parent_name, ha='center', va='center', bbox=dict(facecolor='white', edgecolor='black'))
        if isinstance(tree, dict):
            # Walk every (edge label, child) pair in the dict
            for edge_label, child in tree.items():
                # Place the child: edge label 0 goes left, anything else right
                child_x = parent_x - dx / 2 if edge_label == 0 else parent_x + dx / 2
                child_y = parent_y - dy
                if isinstance(child, dict):
                    child_name = list(child.keys())[0]
                else:
                    child_name = child
                # Draw the edge and annotate it with the edge label
                ax.plot([parent_x, child_x], [parent_y, child_y], 'k-')
                mid_x = (parent_x + child_x) / 2
                mid_y = (parent_y + child_y) / 2
                ax.text(mid_x, mid_y, str(edge_label), ha='center', va='center', fontsize=8,
                        bbox=dict(facecolor='yellow', edgecolor='black'))
                # Draw the child node itself
                ax.text(child_x, child_y, child_name, ha='center', va='center',
                        bbox=dict(facecolor='white', edgecolor='black'))
                # Recurse into the subtree, halving the horizontal spread
                if isinstance(child, dict):
                    _plot_tree(ax, child[child_name], child_name, child_x, child_y, dx / 2, dy)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1.5, 0.5)
    ax.axis('off')
    _plot_tree(ax, tree[root_name], root_name, 0, 0, 1, 0.5)
    plt.show()


def classify_tree(tree: dict, labels: list, test_vec):
    """Classify `test_vec` by walking the tree; `labels` maps feature names to vector indices."""
    first_str = list(tree.keys())[0]
    second_dict = tree[first_str]
    feat_index = labels.index(first_str)
    class_label = ""
    for key, value in second_dict.items():
        if test_vec[feat_index] == key:
            if isinstance(value, dict):
                class_label = classify_tree(value, labels, test_vec)
            else:
                class_label = value
    return class_label


def store_tree(tree: dict, file_path: str):
    """Serialize the tree to disk with pickle."""
    with open(file_path, "wb") as f:
        pickle.dump(tree, f)


def grab_tree(file_path):
    """Load a pickled tree from disk."""
    with open(file_path, "rb") as f:
        return pickle.load(f)


if __name__ == '__main__':
    mat, labels = create_dataset()
    # Pass a copy: create_tree deletes entries from the label list it is given
    tree = create_tree(dataset=mat, labels=labels[:])
    plot_tree(tree, 'no surfacing')
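
    # A minimal demo of the remaining helpers, assuming the file path
    # "tree.pkl" (an arbitrary choice): classify one sample, then round-trip
    # the tree through pickle.
    print(classify_tree(tree, labels, [1, 0]))  # expected: "no"
    store_tree(tree, "tree.pkl")
    print(grab_tree("tree.pkl") == tree)  # expected: True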

Other decision tree examples, as well as decision tree implementations based on mainstream machine learning frameworks, can be found at:
https://gitee.com/navysummer/machine-learning/tree/master/decision_tree