一、基础知识准备:

1.标称型 & 数值型

标称型:标称型目标变量的结果只有在有限的目标集中取值,如TrueFalse(标称型目标变量主要用于分类)

数值型:数值型目标变量则可以从无限的数值集合中取值,如0.10042.001等(数值型目标变量主要用于回归分析)

2.信息熵 & 信息增益

信息熵:度量数据集合无序程度的量
信息增益:信息熵(划分数据集前) - 信息熵(划分数据集后)

3.ID3算法

1). ID3算法是一种 贪心算法,用来构造 决策树
2). ID3算法的核心: 信息熵
3).ID3算法:以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样本`

4.决策树算法的优缺点:

优点:计算复杂度不高,输出结果易于理解,对于中间值的缺失不敏感,可以处理不相关特征数据
缺点:可能产生过度匹配的问题
试用类型: (离散化的)数值型标称型

二、具体决策树的算法:代码及注释

1.算法流程

①收集数据:可以使用任何方法
②准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化
③分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
④训练算法:构造树的数据结构
⑤测试算法:使用经验树计算错误率
⑥使用算法:此算法可适用于任何无监督学习算法,而使用决策树可以更好的理解数据的内在含义

2.具体代码:

from math import log
import operator

#createDataSet()只用作生成测试算法的样本实例,在实际算法中替换为具体的数据集
def createDataSet(): 
    dataSet = [[1, 1,' yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels
#计算香农熵(即信息熵)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)    #计算dataSet的长度
    labelCounts = {}             #存储dataSet中所有的标签分类
    for featVec in dataSet:      #遍历dataSet,添加所有的标签分类到labelCounts字典中
        currentLabel = featVec[-1]  #获得dataSet的每列标签,作为当前标签
        if currentLabel not in labelCounts.keys(): #如果当前标签不在labelCounts字典的key中,就加进去,并初始化其标签的value为0
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 #如果当前标签已经存在了就直接加1,统计出现频率【注意这里的缩进!!!】
    shannonEnt = 0.0
    for key in labelCounts:     #计算熵
            prob = float(labelCounts[key])/numEntries
            shannonEnt -= prob * log(prob,2)
    return shannonEnt
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value): #输入信息:数据集,轴(分类特征),值(轴的返回值)
    retDataSet = []           #准备一个list作为返回dataSet
    for featVec in dataSet:   #遍历dataSet
        if featVec[axis] == value: #如果当前轴的值等于设定的值,就将它剔除
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:]) #【注意extend()和append()的区别】
            retDataSet.append(reducedFeatVec)
    return retDataSet
#选择最佳的数据集划分特征方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1     #计算一个样本具有的特征数
    baseEntropy = calcShannonEnt(dataSet) #计算dataSet初始香农熵(即数据集合原始的最大 无序程度)
    bestInfoGain = 0.0;bestFeature = -1   #初始化熵增益和最佳的划分特征
    for i in range(numFeatures):          #遍历特征集
        featList = [example[i] for example in dataSet] #针对某特征i记录所有样本中的该特征
        uniqueVals = set(featList) #采用“集合”来记录所有唯一的特征(利用集合内无重复元素)
        newEntropy = 0.0
        for value in uniqueVals:   #遍历唯一的特征集,并计算每个特征的熵,且计算后使用splitDataSet()剔除该特征
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy #计算熵的增益,选择增益最大的作为划分特征(增益越大,说明该划分更有效)
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
#当遇到最佳划分特征不唯一时,采用“投票表决”,这部分代码相当于一个模块,类似情景下都可以使用
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in  classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1      #【注意这里的缩进!!!】
    sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
#创建树的数据结构
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList): #如果所有类标签完全相同,则直接返回标签
        return classList[0]
    if len(dataSet[0]) == 1:   #如果使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,则挑选出现次数最多的类别返回
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat]) 
    featValues = [example[bestFeat] for example in dataSet] #用featValues list存储dataSet中最佳划分特征包含的所有属性值
    uniqueVals = set(featValues)  #以“集合”的形式存储特征值,得到唯一属性值的集合
    for value in uniqueVals:      #遍历唯一属性值的集合,拷贝所有标签,并递归地创建树的分支
        subLabels = labels[:]
        myTree[bestFeatLabel][valshiue] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    #最终得到的“树”是一个“嵌套字典”,包含了许多子节点/叶子节点信息的字典
    return myTree
#使用决策树的分类函数--递归函数
def classify(inputTree, featLabels, testVec): #输入信息:决策树、特征标签、特征标签上的值
    firstStr = inputTree.keys()[0]          #获得输入树的第一个标签
    secondDict = inputTree[firstStr]        #递归得到新的树的字典secondDict
    featIndex = featLabels.index(firstStr)  #使用index方法查找当前列表中第一个匹配第一个标签firstStr变量的元素
    for key in secondDict.keys():           #遍历新的树的字典secondDict的键值
        if testVec[featIndex] == key:       #比较textVecs变量中的值与树节点的值
            if type(secondDict[key]).__name__ == 'dict':  #如果新的树的字典secondDict的键值仍为字典,则是判断节点,继续递归调用
                classLabel = classify(secondDict[key],featLabels,testVec)
        else:     #如果新的树的字典的键值不为字典了,说明是到达叶子节点,直接返回叶子节点的标签
                classLabel = secondDict[key] 
    return classLabel
 ```
 ```
 #使用pickle模块将决策树分类器存储在硬盘上,这样每次执行分类函数classify()的时候直接调用已经构造好的决策树                                                          
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

三、使用matplotlib绘制决策树

这部分不用重点关注,用到时多查http://matplotlib.org/即可

import matplotlib.pyplot as plt

#define box and arrow formattings:boxstyle/arrow defines the shape; fc defines the color depth 
decisionNode = dict(boxstyle = "sawtooth",fc = "0.8")
leafNode = dict(boxstyle = "round4",fc = "0.8")
arrow_args = dict(arrowstyle="<-")

#draw annotations with arrows
        #xy=positon of element to annotate.(end position of arrow)
    #xycords = string that indicates what type od coordinates 'xy' is.
    ######(eg.'figure points','figure pixels','figure fraction','axes fraction(0,0 is lower left of axes and 1,1 is upper right)'...http://matplotlib.org/api/text_api.html#matplotlib.text.Annotation...)
    #xytext = position of label in the box
    #textcoord = string that indicates what type of coordinates
    ######(eg.like above)
    #bbox = 
    # arrowprobs = 
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords = 'axes fraction',
        xytext = centerPt,textcoords = 'axes fraction',va = "center",ha = "center",
        bbox = nodeType,arrowprops = arrow_args)

#the key of how to draw figure
def createPlot():
    #new a figure
    fig = plt.figure(1,facecolor = 'white')
    #clear the figure
    fig.clf()
    #create a subplot...frameon = True/False -> Display or not
    createPlot.ax1 = plt.subplot(111,frameon = False)
    plotNode(U'decision_node',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(U'leaf_node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #test to see if the nodes are dictonaires, if not they are leaf nodes
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #test to see if the nodes are dictonaires, if not they are leaf nodes
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

#this fuction is for test,pre-define a Tree
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#add text info between 'parentnode' and 'chilnode'
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString,va ="center",ha = "center",rotation = 30)

#if the first key tells you what feat was split on
def plotTree(myTree,parentPt,nodeTxt):
    #this determines the x width of this tree
    numLeafs = getNumLeafs(myTree)
    #this determines the y height of this tree
    depth = getTreeDepth(myTree)
    #the text label for this node should be this
    firstStr = myTree.keys()[0]
    #"plotTree.totalW" as a global varible to store weight of the tree
    #"plotTree.totalD" as a global varible to store depth of the tree
    #"plotTree.xOff"/"plotTree.yOff" as global varibles to trace the node position
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText (cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    #decrease 'yOff'
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        #test to see if the nodes are dictonaires, if not they are leaf nodes   
        if type(secondDict[key]).__name__ == 'dict':
            #recursion
            plotTree(secondDict[key], cntrPt, str(key))
        #if it's a leaf node, then print the leaf node
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict


#creat the print area
def createPlot(inTree):
    fig = plt.figure(1,facecolor = 'white')
    fig.clf()
    #"xticks"/"yticks" get x-limits/y-limits of current tick position s and labels
    axprops = dict(xticks = [], yticks =[])
    createPlot.ax1 = plt.subplot(111,frameon = False,**axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 /plotTree.totalW;plotTree.yOff = 1.0;
    plotTree(inTree,(0.5,1.0),'')
    plt.show()