3.2 在python中使用matplotlib注解绘制树形图

上节学习了如何从数据集创建数,然而字典的表示形式非常不易于理解,而且直接绘制图像也比较困难。本节将使用matplotlib库创建树形图。决策树的主要优点就是直观易于理解,如果不能将其直观地显示出来,就无法发挥其优势。python并没有提供绘制树的工具,因此必须自己绘制树形图,本节将学习如何编写代码绘制下图所示的决策树:

python绘制决策树对string特征进行划分 用python画决策树_3.决策树(2)


3.2.1matplotlib注解

matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注释。注解通常用于解释数据的内容。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支持带箭头的划线工具,可以在其它恰当的地方指向数据位置,并在此处添加描述信息,解释数据内容。

matplotlib的注解功能可以对文字着色并提供多种形状以供选择,还可以反转箭头,将它指向文本框而不是数据点。创建名为treePlotter.py的新文件,输入下面的程序代码:

##使用文本注解绘制树节点
import matplotlib.pyplot as plt
#定义文本框和箭头格式
decisionNode=dict(boxstyle='sawtooth',fc='0.8')
leafNode=dict(boxstyle='round4',fc='0.8')
arrow_args=dict(arrowstyle='<-')
#绘制带箭头的注解
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va='center',ha='center',bbox=nodeType,arrowprops=arrow_args)
def createPlot():
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1=plt.subplot(111,frameon=False)
    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

代码中plotNode()函数执行实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。python语言中所有的变量默认都是全局有效的,只要清楚的知道当前代码的主要功能,并不会引入太大的麻烦。最后定义的createPlot()函数,是这段代码的核心。createPlot()函数首先创建了一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面将用这两个节点绘制树形图。
为了测试上述代码的实际输出结果,输入:

if __name__=="__main__":
    createPlot()

输出结果:

python绘制决策树对string特征进行划分 用python画决策树_3.决策树(2)_02


3.2.2构造注解树

绘制一棵完整的树需要一些技巧。虽然有x,y坐标,但是如何放置树节点却是个问题。必须知道有多少个叶节点,以便可以正确确定x轴的长度;还需要知道树有多少层,以便可以正确确定y轴的高度。这里定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的数目和树的层数,如下代码所示,将下面两个函数添加到文件treePlotter.py中。

##获取叶节点的数目和树的层数
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    #测试节点的数据类型是否为字典
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=="dict":
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if thisDepth>maxDepth:maxDepth=thisDepth
    return maxDepth

上述程序中的两个函数具有相同的结构,后面也将使用到这两个函数。这里使用的数据结构说明了如何在python字典类型中存储树信息。第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字出发,可以遍历整棵树的所有子节点。使用python提供的type()函数可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用getNumleafs()函数。getNumLeafs()函数遍历整棵树,累计叶子节点的个数,并返回该数值。第二个函数getTreeDepth()计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。为了节省时间,函数retrieveTree输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦。
添加下面的代码到文件treePlotter中:

def retrieveTree(i):
    listOfTrees=[{'no surfing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                 { 'no surfing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]
    return  listOfTrees[i]

测试代码:

if __name__=="__main__":
    myTree=retrieveTree(0)
    print(myTree)
    print(getNumLeafs(myTree))
    print(getTreeDepth(myTree))

测试结果:

python绘制决策树对string特征进行划分 用python画决策树_机器学习实战_03


函数retrieveTree()主要用于测试,返回预定义的树结构。上述命令中调用getNumLeafs()函数返回值为3,等于树0的叶子节点数;调用getTreeDepths()函数也能够正确返回树的层数。

下面绘制一棵完整的树,将以下代码添加到treePlotter.py文件中。此处还需要更新前文定义的函数createPlot():

#在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid=(parentPt[0]-cntrPt[0])/2+cntrPt[0]
    yMid=(parentPt[1]-cntrPt[1])/2+cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    #计算宽和高
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
    #标记节点属性值
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    #减小y偏移
    plotTree.yOff=plotTree.yOff-1/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff=plotTree.xOff+1/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff=plotTree.yOff+1/plotTree.totalD

def createPlot(inTree):
 fig=plt.figure(1,facecolor='white')
 fig.clf()
 axprops=dict(xticks=[],yticks=[])
 createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
 plotTree.totalW=float(getNumLeafs(inTree))
 plotTree.totalD=float(getTreeDepth(inTree))
 plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1;
 plotTree(inTree,(0.5,1.0),' ')
 plt.show()

函数creatPlot()是主函数,它调用了plotTree(),函数plotTree又依次调用了前面介绍的函数和plotMidText()。绘制树形图的很多工作都是在函数plotTree()中完成的,函数plotTree()首先计算树的宽和高。全局变量plotTree.totalW存储树的宽度,全局变量plotTree.totalD存储树的深度,使用这两个变量计算树节点的摆放位置,可以将树绘制在水平方向和垂直方向的中心位置。与getNumLeafs()和getTreeDepth()类似,函数plotTree()也是递归函数。树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不仅仅是它子节点的中间。同时使用两个全局变量plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。另一个需要说明的问题是,绘制图形的x轴有效范围是0.0到1.0,y轴有效范围也是0.0~1.0.通过计算树包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置,也就是说,按照叶子节点的数目将x轴划分为若干部分。按照图形比例绘制树形图的最大好处是无需关心实际输出图形的大小,一旦图形大小发生了变化,函数会自动按照图形大小重新绘制。如果以像素为单位绘制图形,则缩放图形就不是一件简单的工作。
接着,绘制子节点具有的特征值,或者沿此分支向下的数据实例必须具有的特征值。使用函数plotMidText()计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息。
然后,按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点,这些节点即可以是叶子节点也可以是判断节点,此处需要只保存绘制图形的轨迹。因为是自顶向下绘制图形,因此需要依次递减y坐标值,而不是递增y坐标值。然后程序采用函数getNumLeafs()和getTreeDepth()以相同的方式递归遍历整棵树,如果节点是叶子节点则在图形上画出叶子节点,如果不是叶子节点则递归调用plotTree()函数。在绘制了所有子节点之后,增加全局变量Y的偏移。
程序的最后一个函数creatPlot(),它创建绘图区,计算树形图的全局尺寸,并调用递归函数plotTree()。
验证输出结果:

if __name__=="__main__":
    myTree=retrieveTree(0)
    createPlot(myTree)

输出结果:

python绘制决策树对string特征进行划分 用python画决策树_matplotlib注解绘制树形图_04


变更字典,重新绘制树形图:

if __name__=="__main__":
    myTree=retrieveTree(0)
    myTree['no surfing'][3]='maybe'
    createPlot(myTree)

输出结果:

python绘制决策树对string特征进行划分 用python画决策树_机器学习实战_05


3.3 测试和存储分类器

本节将使用决策树构建分类器,以及实际应用中如何存储分类器。

3.3.1测试算法:使用决策树执行分类

依靠训练数据构造了决策树之后,可以将它用于实际数据的分类。在执行数据分类时,需要决策树以及用于构造树的标签向量。然后程序比较测试数据和决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。

为了验证算法的实际效果,将下列程序代码添加到文件tree.py中。

def classify(inputTree,featLabels,testVec):
    firstStr=list(inputTree.keys())[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel

上述代码定义的函数也是一个递归函数,在存储带有特征的数据会面临一个问题:程序无法确定特征在数据集中的位置,例如前面例子的第一个用于划分数据集的特征是no surfing属性,但是在实际数据集中该属性存储在哪个位置?是第一个属性还是第二个属性?特征标签列表将帮助程序处理这个问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素。然后代码递归遍历整棵树,比较testVec变量中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签。
测试代码:

import tree #见上一节
if __name__=="__main__":
    myDat,labels=tree.creatDataSet()
    myTree=retrieveTree(0)
    print(labels)
    print(myTree)
    print(classify(myTree,labels,[1,0]))
    print(classify(myTree,labels,[1,1]))

测试结果:

python绘制决策树对string特征进行划分 用python画决策树_机器学习实战_06


由输出结果可以看出。第一个节点名为no surfing,它有两个子节点:一个是名字为0的叶子节点,类标签为no;另一个是名为flippers的判断节点,此处进入递归调用,flippers节点有两个节点。以前绘制的树形图和此处代表树的数据结构完全相同。

现在已经创建了使用决策树的分类器,但是每次使用分类器时,必须重新构造决策树,下面将介绍如何在硬盘上存储决策树分类器。

3.3.2使用算法:决策树的存储

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用python模块pickle序列化对象,如下列程序所示。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。

#使用pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb+')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr=open(filename,'rb+')
    return pickle.load(fr)

测试代码效果:

if __name__=="__main__":
    #myDat,labels=tree.creatDataSet()
    myTree=retrieveTree(0)
    storeTree(myTree,'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))

测试结果:

python绘制决策树对string特征进行划分 用python画决策树_3.决策树(2)_07


通过上面的代码,可以将分类器存储在硬盘上,而不是每次对数据分类时重新学习一遍,这也是决策树的优点之一,k-近邻算法就无法持久化分类器。可以预先提炼并存储数据集中包含的知识信息,在需要对事物进行分类时再使用这些知识。

3.4 示例:使用决策树预测隐形眼镜类型

本节将通过一个例子讲解决策树如何预测患者需要佩戴的隐形眼镜类型。使用小数据集,就可以利用决策树学到很多知识:眼科医生是如何判断患者需要佩戴的镜片类型;一旦理解了决策树的工作原理,甚至可以帮助人们判断需要佩戴的镜片类型。

示例:使用决策树预测隐形眼镜类型

(1)收集数据:提供的文本文件。

(2)准备数据:解析tab键分隔的数据行。

(3)分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图。

(4)训练算法:使用3.1节中的createTree()函数。

(5)测试算法:编写测试函数验证决策树可以正确分类给定的数据实例。

(6)使用算法:存储树的数据结构,以便下次使用时无需重新构造树。

隐形眼镜数据集(数据来源)是非常著名的数据集,它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。数据来源于UCI数据库,为了更容易显示数据,按照文件说明对数据做简单的更改:

文件说明:

python绘制决策树对string特征进行划分 用python画决策树_matplotlib注解绘制树形图_08


更改后的数据:

python绘制决策树对string特征进行划分 用python画决策树_matplotlib注解绘制树形图_09


数据存储在源代码下载路径的文本文件中,新建lensesChoose.py文件进行操作,加载数据集:

import tree
fr=open('lenses.data')
lenses=[inst.strip().split(' ') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=tree.createTree(lenses,lensesLabels)
print(lensesTree)

结果:
{‘tearRate’: {‘normal’: {‘astigmatic’: {‘no’: {‘age’: {‘pre-presbyopic’: ‘soft’, ‘young’: ‘soft’, ‘presbyopic’: {‘prescript’: {‘myope’: ‘nolenses’, ‘hypermetrope’: ‘sofy’}}}}, ‘yes’: {‘prescript’: {‘myope’: ‘hard’, ‘hypermetrope’: {‘age’: {‘pre-presbyopic’: ‘nolenses’, ‘young’: ‘hard’, ‘presbyopic’: ‘nolenses’}}}}}}, ‘reduce’: ‘nolenses’}}
采用文本方式很难分辨出决策树的模样,调用以下命令绘制树形图:

import treePlotter
treePlotter.createPlot(lensesTree)

由ID3算法产生的决策树:

python绘制决策树对string特征进行划分 用python画决策树_子节点_10


沿着决策树的不同分支,可以得到不同患者需要佩戴的隐形眼镜类型。从上图可以发现,医生最多需要问四个问题就能确定患者需要佩戴哪种类型的隐形眼镜。

上图所示的决策树非常好地匹配了实验数据,然而这些匹配选项可能太多了。将这种问题称之为过度匹配(overfitting)。为了减少过度匹配问题,可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其它叶子节点中。

本章使用的算法称为ID3算法,它是一个好的算法但不完美。ID3算法无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题。

3.5 本章小结

决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,首先需要测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3算法可以用于划分标称型数据集。构造决策树时,通常采用递归方法将数据集转化为决策树。一般并不构造新的数据结构,而是使用python语言内嵌的数据结构字典存储树节点信息。

使用matplotlib的注解功能,可以将存储的树结构转化为容易理解的图形。python语言的pickle模块可用于存储决策树的结构。隐形眼镜的例子表明决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题。

还有其它的决策树的构造算法,最流行的是C4.5和CART。