It's all in the code and comments below.


import operator          # needed by majorityCnt for operator.itemgetter
from math import log

def calcShannonEnt(dataSet):
	numEntries = len(dataSet)
	labelCounts = {}
	for featVec in dataSet:
		currentLabel = featVec[-1]
		if currentLabel not in labelCounts.keys():
			labelCounts[currentLabel] = 0
		labelCounts[currentLabel] += 1

	shannonEnt = 0.0
	for key in labelCounts:
		prob = float(labelCounts[key])/numEntries
		shannonEnt -= prob * log(prob, 2)
	return shannonEnt

def createDataSet():
	dataSet = [[1,1,'yes'],
			[1,1,'yes'],
			[1,0,'no'],
			[0,1,'no'],
			[0,1,'no']]
	labels = ['no surfacing', 'flippers']
	return dataSet, labels
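
# Illustrative worked example (my own sanity check on the sample data above, not output I ran here):
# calcShannonEnt computes H(D) = -sum_k p_k * log2(p_k) over the class labels of dataSet.
#
#   myDat, labels = createDataSet()
#   calcShannonEnt(myDat)
#   # myDat has 2 'yes' and 3 'no' labels, so
#   # H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710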

def splitDataSet(dataSet, axis, value):
	retDataSet = []
	for featVec in dataSet:
		if featVec[axis] == value:
			reducedFeatVec = featVec[:axis]
			reducedFeatVec.extend(featVec[axis+1:])
			retDataSet.append(reducedFeatVec)
	return retDataSet
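
# Illustrative example on the sample data: axis 0, value 1 keeps the three samples whose
# first feature is 1 and strips that feature out of each of them.
#
#   splitDataSet(myDat, 0, 1)
#   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
#   splitDataSet(myDat, 0, 0)
#   # [[1, 'no'], [1, 'no']]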

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # Number of features per sample (every sample has the same
                                       # number of features, and the last column is the class
                                       # label, so subtract 1). In our example numFeatures is 2.

    baseEntropy = calcShannonEnt(dataSet)  # Shannon entropy of the full data set
    bestInfoGain = 0.0                     # best information gain found so far
    bestFeature = -1                       # index of the best feature to split on

    for i in range(numFeatures):  # i is the feature index; with range(2), i takes the values 0 and 1
        featList = [sample[i] for sample in dataSet]
        # Collect the value of feature i for every sample. For the first feature this
        # gives [1, 1, 1, 0, 0].
        uniqueVals = set(featList)  # A set drops the duplicates, e.g. set([1, 1, 1, 0, 0]) == {0, 1},
                                    # leaving only the distinct values of feature i.

        newEntropy = 0.0
        for value in uniqueVals:  # loop over every distinct value of feature i
            subDataSet = splitDataSet(dataSet, i, value)
            # Split off the subset of samples whose feature i equals this value; e.g. for the
            # first feature, value 1 gives one subset and value 0 gives another on the next pass.
            prob = len(subDataSet)/float(len(dataSet))  # weight of this subset: the fraction of
                                                        # samples that fall into it
            newEntropy += prob * calcShannonEnt(subDataSet)

        infoGain = baseEntropy - newEntropy  # information gain of splitting on feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature
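
# Illustrative worked example on the sample data (numbers rounded):
#   base entropy ≈ 0.9710
#   feature 0 ('no surfacing'): newEntropy = (2/5)*0 + (3/5)*0.9183 ≈ 0.5510, infoGain ≈ 0.4200
#   feature 1 ('flippers')    : newEntropy = (1/5)*0 + (4/5)*1.0000 = 0.8000, infoGain ≈ 0.1710
# so chooseBestFeatureToSplit(myDat) returns 0.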

def majorityCnt(classList): # classList holds the class labels that remain when every feature
                            # has been used up, e.g. ['yes', 'yes', 'maybe']
    classCount = {}  # dictionary mapping each class label to its count
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    # For the example above this loop produces {'yes': 2, 'maybe': 1}.
    # classCount.items() yields (label, count) pairs (iteritems() in Python 2); sort them by
    # count in descending order and return the label of the majority class.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
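
# Illustrative example:
#   majorityCnt(['yes', 'yes', 'maybe'])
#   # {'yes': 2, 'maybe': 1} sorted by count gives [('yes', 2), ('maybe', 1)], so 'yes' is returned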

def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList): 
        return classList[0] #stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: 
        return majorityCnt(classList)
		
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    subLabels = labels[:]   # copy the labels so recursive calls don't mess up the existing label list
    del(subLabels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals: 
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree
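
# Illustrative example: building the tree for the sample data reproduces the first tree
# stored in retrieveTree in treePlotter.py below.
#
#   myDat, labels = createDataSet()
#   createTree(myDat, labels)
#   # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}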

treePlotter.py

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # Style of a decision node: boxstyle sets the shape
                                                    # of the text box ("sawtooth" is jagged), fc is the
                                                    # face (fill) colour of the box, here grey level 0.8.
# Equivalent to decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leafNode = dict(boxstyle="round4", fc="0.8")  # style of a leaf node
arrow_args = dict(arrowstyle = "<-")  # style of the arrows that connect the nodes

# def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#     createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
#                             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#     # nodeTxt is the text to display, centerPt is the centre of the text box, and parentPt is the
#     # tail of the arrow that points at the text; xy is the coordinate of the arrow tail.
#     # xytext is the position of the annotation text; if it is None the text is placed at xy.
#     # xycoords and textcoords say how xy and xytext are interpreted (here as axes fractions);
#     # if textcoords is None it defaults to xycoords, and if neither is set the default is 'data'.
#     # va / ha set the vertical / horizontal alignment of the text in the node box:
#     # va is one of ('top', 'bottom', 'center', 'baseline'), ha is one of ('center', 'right', 'left').

# def createPlot():
#     fig = plt.figure(1, facecolor = 'white')  # create a figure with a white background
#     fig.clf()                                 # clear the figure
#     # ax1 is an attribute attached to the createPlot function itself; it can be assigned inside
#     # the function or after the function has been defined.
#     createPlot.ax1 = plt.subplot(111, frameon = True)  # frameon: whether to draw the axes frame
#     plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

# createPlot()

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':  # test whether the child is a dictionary (a subtree); if not, it is a leaf node
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':  # test whether the child is a dictionary (a subtree); if not, it is a leaf node
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# Draw a node centred at centerPt containing the text nodeTxt, plus an arrow whose tail is at
# parentPt and whose head is at centerPt.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',\
             xytext=centerPt, textcoords='axes fraction',\
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Write the text txtString at the midpoint between cntrPt and parentPt.
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
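
# Illustrative example: plotMidText simply writes txtString at the midpoint of the segment
# from cntrPt to parentPt, e.g. for cntrPt = (0.2, 0.5) and parentPt = (0.6, 0.9) the text
# lands at (0.4, 0.7).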


def plotTree(myTree, parentPt, nodeTxt):  # nodeTxt is the text written on the arrow, i.e. the branch value
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  # cntr = center
    plotMidText(cntrPt, parentPt, nodeTxt)  # For the top-level call in createPlot, plotTree(inTree, (0.5, 1.0), ''),
                                            # cntrPt and parentPt are the same point (0.5, 1.0) and nodeTxt is '',
                                            # so this call draws nothing visible.
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # firstStr is the feature tested at the root of this
                                                        # (sub)tree, e.g. 'no surfacing'. At the top level
                                                        # cntrPt == parentPt, so the result is a decisionNode
                                                        # box containing firstStr with no visible arrow.
    secondDict = myTree[firstStr]  # the dict of subtrees / leaves hanging below this decision node
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # move down one level before drawing the children
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':  # test whether the child is a dictionary (a subtree); if not, it is a leaf node
            plotTree(secondDict[key], cntrPt, str(key))  # recurse; from here on the branch value str(key) is written on the arrow
        else:   # it's a leaf node: plot the leaf
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # this subtree is done: move back up one level
# if you do get a dictionary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks shown, for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]
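
# Illustrative example, using the first stored tree:
#
#   myTree = retrieveTree(0)
#   getNumLeafs(myTree)   # 3
#   getTreeDepth(myTree)  # 2
#   createPlot(myTree)    # draws the annotated tree figure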

# myTree = retrieveTree(1)
# createPlot(myTree)