from math import log
def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    The class label of each sample is its last element. An empty data set
    has entropy 0.0.
    """
    total = len(dataSet)
    counts = {}
    for sample in dataSet:
        label = sample[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for count in counts.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
def createDataSet():
    """Return a small toy data set and the names of its feature columns.

    Each sample is [no_surfacing, flippers, class_label]. The two features
    are binary, and the class label ('yes' / 'no') is the last element, as
    calcShannonEnt and createTree expect. New lists are built on every call,
    so callers may mutate the result freely.

    Returns:
        (dataSet, labels) where labels[i] names feature column i.
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
def splitDataSet(dataSet, axis, value):
    """Return the samples whose feature `axis` equals `value`, minus that column.

    Args:
        dataSet: sequence of samples (lists); the last element of each sample
            is the class label.
        axis: non-negative index of the feature column to split on.
        value: feature value to select.

    Returns:
        A new list of new rows. Each row is a matching sample with column
        `axis` removed, so the remaining feature indices shift down by one
        past `axis`. The input rows are not modified (slicing copies them).
    """
    # Keep everything before and after the split column.
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet
            if featVec[axis] == value]
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest information gain.

    Every sample in dataSet is a list whose last element is the class label;
    the other elements are feature values, so there are len(dataSet[0]) - 1
    candidate features. For feature i the gain is

        H(dataSet) - sum over distinct values v of (|S_v| / |dataSet|) * H(S_v)

    where S_v = splitDataSet(dataSet, i, v) and H is calcShannonEnt.

    Only a strictly larger gain replaces the current best. As a result, ties
    keep the lowest index, and -1 is returned when no feature has a positive
    gain.
    """
    featureCount = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    total = float(len(dataSet))
    bestGain, bestIndex = 0.0, -1
    for idx in range(featureCount):
        # Distinct values of this feature, e.g. [1, 1, 1, 0, 0] -> {0, 1}.
        distinctValues = {sample[idx] for sample in dataSet}
        splitEntropy = 0.0
        for v in distinctValues:
            subset = splitDataSet(dataSet, idx, v)
            weight = len(subset) / total  # fraction of samples with feature == v
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestIndex = gain, idx
    return bestIndex
def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote).

    createTree calls this when no feature columns are left to split on.
    Ties go to the label that appears first in classList.

    Args:
        classList: sequence of hashable class labels, e.g. ['yes', 'maybe', 'yes'].

    Raises:
        IndexError: if classList is empty.
    """
    from collections import Counter  # local import keeps this block self-contained

    # most_common(1) orders equal counts by first appearance, so ties are
    # resolved in favour of the earliest label.
    return Counter(classList).most_common(1)[0][0]
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: non-empty list of samples. Each sample is a list whose last
            element is the class label; the other elements are feature values.
        labels: feature names, one per feature column. This list is not
            modified.

    Returns:
        Either a bare class label (a leaf), or a nested dict of the form
        {feature_name: {feature_value: subtree_or_label, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Every sample has the same class: stop splitting.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains (all features consumed): majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # No feature has a positive information gain (e.g. XOR-like data).
        # Index -1 would select the class column, so split on the first
        # feature instead. Recursion still terminates because each split
        # removes one feature column.
        bestFeat = 0
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # splitDataSet removes column bestFeat from every row, so remove its name
    # too. This keeps labels aligned with the remaining columns. A new list is
    # built, leaving the caller's labels untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
import matplotlib.pyplot as plt
# Text-box style for decision (internal) nodes: a sawtooth outline and a face
# colour ("fc" = facecolor) of "0.8", a grey level given as a string.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Text-box style for leaf nodes: a rounded "round4" outline with the same grey face.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style for the edges. "<-" is the reverse of "->", so the head should
# sit at the annotated text (the child node) rather than at xy (the parent).
arrow_args = {"arrowstyle": "<-"}
# def plotNode(nodeTxt, centerPt, parentPt, nodeType):
# createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
# va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# #nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的箭尾点,xy是箭头尾的坐标
# #xytext -- 为注解内容位置坐标,当该值为None时,注解内容放置在xy处
# #xycoords和textcoords是坐标xy与xytext的说明(按轴坐标),若textcoords=None,则默认textcoords与xycoords相同,若都未设
# #置,默认为data
# #va/ha设置节点框中文字的位置,va为纵向取值为(u'top', u'bottom', u'center', u'baseline'),ha为横向取值
# #为(u'center', u'right', u'left')
# def createPlot():
# fig = plt.figure(1, facecolor = 'white') #创建一个画布,背景为白色
# fig.clf() #画布清空
# #ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以
# createPlot.ax1 = plt.subplot(111, frameon = True) #frameon表示是否绘制坐标轴矩形
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
# createPlot()
def getNumLeafs(myTree):
    """Count the leaf nodes of a tree shaped like createTree's output.

    A tree is a single-key dict {feature_label: {feature_value: child, ...}}.
    A child that is itself a dict (including dict subclasses) is a subtree;
    any other child is a leaf.

    Raises:
        IndexError: if myTree is empty.
    """
    firstStr = list(myTree)[0]  # the root's feature label
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a tree shaped like createTree's output.

    A tree is a single-key dict {feature_label: {feature_value: child, ...}}.
    A dict child (including dict subclasses) is a subtree, which adds its own
    depth; any other child is a leaf at depth 1. A root with no children has
    depth 0.

    Raises:
        IndexError: if myTree is empty.
    """
    firstStr = list(myTree)[0]  # the root's feature label
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        thisDepth = 1 + getTreeDepth(child) if isinstance(child, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
# 画出一个以centerPt为中心的结点 和 一个以parentPt为箭尾,centerPt为箭头的箭,并且结点中会有文字nodeTxt
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node with text nodeTxt at centerPt, plus an arrow to parentPt.

    Both points are in axes-fraction coordinates of createPlot.ax1. nodeType is
    the bbox style dict (decisionNode or leafNode); arrow_args styles the arrow.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
# 在点cntrPt和parentPt之间写点文字txtString
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString, rotated 30 degrees, halfway between cntrPt and parentPt."""
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    midX = childX + (parentX - childX) / 2.0
    midY = childY + (parentY - childY) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree, connecting its root to parentPt.

    nodeTxt is written halfway along the edge from parentPt to this subtree's
    root (the feature value leading here); the root call passes ''.

    Layout state lives in function attributes that createPlot initializes:
    - plotTree.totalW, plotTree.totalD: leaf count and depth of the whole tree.
    - plotTree.xOff: x position of the last leaf drawn.
    - plotTree.yOff: current y level.
    All coordinates are axes fractions.
    """
    numLeafs = getNumLeafs(myTree)  # horizontal width of this subtree, in leaves
    firstStr = list(myTree.keys())[0]  # feature label shown in this node
    # Centre this node horizontally above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    # For the root call made by createPlot, parentPt equals cntrPt (0.5, 1.0)
    # and nodeTxt is '', so no edge label or visible arrow is drawn.
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Children are drawn one level lower.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        child = secondDict[key]
        if isinstance(child, dict):
            # Subtree: its edge is labelled with the feature value str(key).
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance to the next leaf slot, draw it, label its edge.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y level on the way back up the recursion.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
#if you get a dictionary, you know it's a tree, and the value under its first key will be another dict
def createPlot(inTree):
    """Draw a decision tree (shaped like createTree's output) in figure 1 and show it.

    Initializes the layout attributes that plotTree reads (totalW, totalD,
    xOff, yOff) and stores the drawing axes on createPlot.ax1 for plotNode
    and plotMidText.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear figure 1 so a second call does not overlay the previous tree
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf-width left of 0 so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
def retrieveTree(i):
    """Return pre-built sample tree number i (0 or 1), e.g. for createPlot.

    The trees are rebuilt on every call, so callers may mutate the result.

    Raises:
        IndexError: if i is not a valid index into the two sample trees.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]
# myTree = retrieveTree(1)
# createPlot(myTree)