from math import log
import operator
import treePlotter
#计算给定数据集的香农熵的函数
def calcShannonEnt(dataSet):
    # 求list的长度,表示计算参与训练的数据量
    numEntries=len(dataSet)
    labelCounts={}
     # 计算分类标签label出现的次数
    for featVec in dataSet:
    # 将当前实例的标签存储,即每一行数据的最后一个数据代表的是标签
        currentLabel=featVec[-1]
    # 为所有可能的分类创建字典,如果当前的键值不存在,则扩展字典并将当前键值加入字典。每个键值都记录了当前类别出现的次数。
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    # 对于 label 标签的占比,求出 label 标签的香农熵
    for key in labelCounts:
        # 使用所有类标签的发生频率计算类别出现的概率。
        prob=float(labelCounts[key])/numEntries
         # 计算香农熵,以 2 为底求对数
        shannonEnt-=prob*log(prob,2)
    return shannonEnt
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        # axis列为value的数据集【该数据集需要排除index列】
        # 判断axis列的值是否为value
        if featVec[axis]==value:
            # [:axis]表示前axis行,即若 axis 为2,就是取 featVec 的前axis行
            reduceFeatVec=featVec[:axis]
             # [axis+1:]表示从跳过axis的axis+1行,取接下来的数据
            # 收集结果值axis列为value的行【该行需要排除axis列】
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
     # 求第一行有多少列的 Feature, 最后一列是label列嘛
    numFeatures=len(dataSet[0])-1
     # 数据集的原始信息熵
    baseEntropy=calcShannonEnt(dataSet)
     # 最优的信息增益值, 和最优的Featurn编号
    bestInfoGain=0.0;bestFeature=-1
     #迭代所有特征
    for i in range(numFeatures):
        #创建list# 获取对应的feature下的所有数据
        featList=[example[i] for example in dataSet]
        # 获取剔重后的集合,使用set对list数据进行去重
        uniqueVals=set(featList)
         # 创建一个临时的信息熵
        newEntropy=0.0
         # 遍历某一列的value集合,计算该列的信息熵
         # 遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,计算数据集的新熵值,并对所有唯一特征值得到的熵求和。
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)
             # 计算概率
            prob=len(subDataSet)/float(len(dataSet))
             # 计算信息熵
            newEntropy+=prob*calcShannonEnt(subDataSet)
        # gain[信息增益]: 划分数据集前后的信息变化, 获取信息熵最大的值
        # 信息增益是熵的减少或者是数据无序度的减少。最后,比较所有特征中的信息增益,返回最好特征划分的索引值。
        infoGain=baseEntropy-newEntropy
        if(infoGain>bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature

def majorityCnt(classList):
    classCount={}
    for vote in classCount:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote]+=1
        sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]
#训练算法:构造树的数据结构
def createTree(dataSet,labels):
    # 如果数据集的最后一列的第一个值出现的次数=整个集合的数量,也就说只有一个类别,就只直接返回结果就行
    # 第一个停止条件:所有的类标签完全相同,则直接返回该类标签。
    # count() 函数是统计括号中的值在list中出现的次数
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    # 如果数据集只有1列,那么最初出现label次数最多的一类,作为结果
    # 第二个停止条件:使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。
    if len(dataSet[0])==1:
        return majorityCnt(classList)
     # 选择最优的列,得到最优列对应的label含义
    bestFeat=chooseBestFeatureToSplit(dataSet)
    # 获取label的名称
    bestFeatLabel=labels[bestFeat]
     # 初始化myTree
    myTree={bestFeatLabel:{}}
     # 注:labels列表是可变对象,在PYTHON函数中作为参数时传址引用,能够被全局修改
    # 所以这行代码导致函数外的同名变量被删除了元素,造成例句无法执行,提示'no surfacing' is not in list
    del(labels[bestFeat])
    # 取出最优列,然后它的branch做分类
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
         # 求出剩余的标签label
        subLabels=labels[:]
         # 遍历当前选择特征包含的所有属性值,在每个数据集划分上递归调用函数createTree()
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']
    return  dataSet,labels
#测试算法:使用决策树执行分类
def classify(inputTree,featLabels,testVec):
    # 获取tree的根节点对于的key值
    #根据python3的特性进行修改
    firstStr1=list(inputTree.keys())
    firstStr=firstStr1[0]
     # 通过key得到根节点对应的value
    secondDict=inputTree[firstStr]
    # 判断根节点名称获取根节点在label中的先后顺序,这样就知道输入的testVec怎么开始对照树来做分类
    featIndex=featLabels.index(firstStr)
    # 测试数据,找到根节点对应的label位置,也就知道从输入的数据的第几位来开始分类
    for key in secondDict.keys():
        if testVec[featIndex]==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else: classLabel=secondDict[key]
    return classLabel
myDat,labels=createDataSet()
myTree=treePlotter.retrieveTree(0)
#print(classify(myTree,labels,[1,0]))
#print(classify(myTree,labels,[1,1]))
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr=open(filename,'rb')
    return pickle.load(fr)
#print(storeTree(myTree,'classifierStorage.txt'))
#print(grabTree('classifierStorage.txt'))
fr=open('lenses.txt')
lenses=[inst.strip().split('\t')for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=createTree(lenses,lensesLabels)