多叉分类树

下面实现的分类树只限于特征是离散变量,而连续变量不能处理。另外,西瓜书介绍的缺失值的处理多变量处理均未实现。下面实现的树有一个共同的特点,它的分支依据都是一个具体的特征取值,且每次特征选择之后都要删除特征

一、python实现

我使用python的类实现多分叉决策树,包括决策树的训练和预测两部分。

1.1树的结构

使用python的字典(dict)作为树的结点,字典的嵌套形成树,格式如下

{'#':feature_name,'feature_value':{}}   #树的结点
#特征名字为0,取值为0的分支
{'#': 0, 0: 0, 1: {'#': 1, 0: 0, 1: 1}} #例子
{'#':feature_name,'feature_value':{}}   #树的结点
#特征名字为0,取值为0的分支
{'#': 0, 0: 0, 1: {'#': 1, 0: 0, 1: 1}} #例子

1.2 种树

1.2.1 种树流程

建树的过程就是迭代选择划分的特征,每一次迭代选择一个特征进行划分。决策树的训练一般遵循以下两个步骤:

  1. 特征选择
  2. 进入下一次递归(给子集进行特征选择)

其中,迭代返回的情况有

  • 类别值都一样,返回该类别
  • 特征值都一样,返回类别频数最大的哪一类

1.2.2 特征选择指标

特征选择就是选择“纯度”(混乱程度越低)最大的特征。前面提到,信息增益信息增益率基尼指数都可以用于特征选择。接下来根据它们的公式,可以依次写出相应的函数,用于选择纯度最大的特征。

  • 权值
    下面的公式中的$p_k$(概率)或者$\frac{|D^v|}{|D|}$(权值)都可以用这个公式计算。其中注意的是两个参数都是数组类型。
def cal_weight(y,w=None):
    '''计算离散变量的权值\概率
    :param y: 数组,arr
    :param w: 样本权值,arr
    :return:
    '''
    unique_val = set(y) #用数组还是字典存储结果?用生成器
    if w is None:
        m = len(y)
        for v in unique_val:
            yield v,sum(v==y)/m   #用生成器返回结果:取值,权值\概率
    else:
        sum_ = sum(w)
        for v in unique_val:
            yield v,sum(w[y==v])/sum_     #用生成器返回结果:取值,权值\概率
    yield None,0 #y为空的情况
def cal_weight(y,w=None):
    '''计算离散变量的权值\概率
    :param y: 数组,arr
    :param w: 样本权值,arr
    :return:
    '''
    unique_val = set(y) #用数组还是字典存储结果?用生成器
    if w is None:
        m = len(y)
        for v in unique_val:
            yield v,sum(v==y)/m   #用生成器返回结果:取值,权值\概率
    else:
        sum_ = sum(w)
        for v in unique_val:
            yield v,sum(w[y==v])/sum_     #用生成器返回结果:取值,权值\概率
    yield None,0 #y为空的情况
  • 信息熵
    这里的信息熵不直接作为特征选择指标,而是作为信息增益的一部分

$$
Ent(D)=-\sum_{k=1}^np_k\log_2{p_k}
$$

# 计算信息熵
def Ent(y,w): #计算信息熵只需要用到数据集D中的因变量y
    '''
    :param y:因变量y,shpae =(m);arr类型
    :param w: 样本权值,arr
    :return:
    '''
    ent = 0
    for v,p in cal_weight(y,w):
        ent -= p*np.log2(p)
    return ent
# 计算信息熵
def Ent(y,w): #计算信息熵只需要用到数据集D中的因变量y
    '''
    :param y:因变量y,shpae =(m);arr类型
    :param w: 样本权值,arr
    :return:
    '''
    ent = 0
    for v,p in cal_weight(y,w):
        ent -= p*np.log2(p)
    return ent
  • 信息增益(ID3)

$$
Gain(D,x_i)=Ent(D)-\sum_{i=1}^v\frac{|D^v|}{|D|}Ent(D^v)\
其中|D^v|是所有取值为v的样本数量
$$

def Gain(x_i,y,ent,w):
    '''
    :param x_i:第i个特征(属性),1*m
    :param w: 样本权值,arr
    :return:
    '''
    gain = ent  #信息增益
    for v,p in cal_weight(x_i,w):
        index = x_i == v    #取值为v的索引
        w_ = w if w is None else w[index]
        gain -= p**Ent(y[index],w_)
    return gain
def Gain(x_i,y,ent,w):
    '''
    :param x_i:第i个特征(属性),1*m
    :param w: 样本权值,arr
    :return:
    '''
    gain = ent  #信息增益
    for v,p in cal_weight(x_i,w):
        index = x_i == v    #取值为v的索引
        w_ = w if w is None else w[index]
        gain -= p**Ent(y[index],w_)
    return gain
  • 信息增益率(C4.5)

$$
Gain_radio(D,x_i)=\frac{Gain(D,x_i)}{IV(x_i)}\
其中属性x_i的“固有值”\
IV=-\sum_i^v\frac{|D_v|}{|D|}\log_2\frac{|D_v|}{|D|}
$$

#第i个特征的信息增益率
def Gain_Radio(x_i,y,ent,w ):
    '''
    :param x_i:第i个特征(属性),1*m
    :return:
    '''
    gain = ent  #信息增益
    iv = 1e-9  #固有值,平滑处理
    for v,p in cal_weight(x_i,w):
        index = x_i == v  # 取值为v的索引
        w_ = w if w is None else w[index]
        gain -= p**Ent(y[index],w_)
        iv -=p*np.log2(p)
    return gain/iv
#第i个特征的信息增益率
def Gain_Radio(x_i,y,ent,w ):
    '''
    :param x_i:第i个特征(属性),1*m
    :return:
    '''
    gain = ent  #信息增益
    iv = 1e-9  #固有值,平滑处理
    for v,p in cal_weight(x_i,w):
        index = x_i == v  # 取值为v的索引
        w_ = w if w is None else w[index]
        gain -= p**Ent(y[index],w_)
        iv -=p*np.log2(p)
    return gain/iv
  • Gini(基尼值)
    基尼值也不直接作为特征选择指标,而是作为基尼指数的一部分
#第i个特征的基尼值
def Gini(y,w):
    p_2 = 0
    for v,p in cal_weight(y,w):
        p_2 += p**2
    return 1- p_2
#第i个特征的基尼值
def Gini(y,w):
    p_2 = 0
    for v,p in cal_weight(y,w):
        p_2 += p**2
    return 1- p_2
  • 基尼指数
#第i个特征的基尼指数
def Gini_index(x_i,y,w):
    gini_index = 0
    for v,p in cal_weight(x_i,w):
        index = x_i == v  # 取值为v的索引
        w_ = w if w is None else w[index]
        gini_index += p**Gini(y[index],w_)
    return gini_index
#第i个特征的基尼指数
def Gini_index(x_i,y,w):
    gini_index = 0
    for v,p in cal_weight(x_i,w):
        index = x_i == v  # 取值为v的索引
        w_ = w if w is None else w[index]
        gini_index += p**Gini(y[index],w_)
    return gini_index

1.2.3 生成树(种树)

下面是决策树的整体结构。接下来解释构造函数三个参数的作用:

  • criterion:选择特征选择方法
  • splitter:选择是否随机特征选择
  • weight:样本权重

其中splitter、weight有何作用?答案是用来种森林。

若splitter选择'random',可以用来写ExtraTree(极度随机森林)

若指定weight,可以用来写AdaBoost(...森林)

#多叉分类树
class ClassifyTree_:
    def __init__(self,criterion="gini",splitter='best',weight=None):
        self.criterion = criterion
        self.weight = weight
        self.splitter = splitter

#----------特征选择方法-----------------
    def id3(self,X,y,weight):       #criterion="id3",splitter='best'
    def c45(self,X,y,weight):       #criterion="C45",splitter='best' 
    def gini(self,X,y,weight):      #criterion="gini",splitter='best'
    def rand_(self,X,y,weight):     #splitter='random'              

#----------种树-------------------------
    def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
    def fit(self, X, y,weight=None):
        # 四种不同的树
        self.weight = weight
        if self.splitter == 'best':
            if self.criterion == 'id3':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
            elif self.criterion == 'c45':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.c45, weight)
            elif self.criterion == 'gini':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.cart, weight)
            else:
                raise ('gini/c45/id3')
        else:
            self.tree = self.build_(X, y, list(range(X.shape[1])), self.rand_, weight)
        return self
#----------预测-------------------------
    def predict(self, X):
#多叉分类树
class ClassifyTree_:
    def __init__(self,criterion="gini",splitter='best',weight=None):
        self.criterion = criterion
        self.weight = weight
        self.splitter = splitter

#----------特征选择方法-----------------
    def id3(self,X,y,weight):       #criterion="id3",splitter='best'
    def c45(self,X,y,weight):       #criterion="C45",splitter='best' 
    def gini(self,X,y,weight):      #criterion="gini",splitter='best'
    def rand_(self,X,y,weight):     #splitter='random'              

#----------种树-------------------------
    def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
    def fit(self, X, y,weight=None):
        # 四种不同的树
        self.weight = weight
        if self.splitter == 'best':
            if self.criterion == 'id3':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
            elif self.criterion == 'c45':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.c45, weight)
            elif self.criterion == 'gini':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.cart, weight)
            else:
                raise ('gini/c45/id3')
        else:
            self.tree = self.build_(X, y, list(range(X.shape[1])), self.rand_, weight)
        return self
#----------预测-------------------------
    def predict(self, X):

1.3 例子

下面实现的分类树只限于特征是离散变量,而连续变量不能处理。另外,西瓜书介绍的缺失值的处理多变量处理均未实现。阅读这些例子可以轻松理解上面的建树流程。注意,下面的例子都是简易版本的决策树,而非完整版。

1.3.1 ID3决策树

  • 使用信息增益划分数据集
# 使用id3拿到最佳特征的索引
    def id3(self,X,y,weight):
        best_Index = -1
        best_gain = -np.inf
        ent = Ent(y,self.weight)
        for i in range(X.shape[1]):
            gain = Gain(X[:,i],y,ent,weight)
            if gain > best_gain:
                best_gain = gain
                best_Index = i  #信息增益最大的特征
        return best_Index
        
    # 使用id3拿到最佳特征的索引
    def id3(self,X,y,weight):
        best_Index = -1
        best_gain = -np.inf
        ent = Ent(y,self.weight)
        for i in range(X.shape[1]):
            gain = Gain(X[:,i],y,ent,weight)
            if gain > best_gain:
                best_gain = gain
                best_Index = i  #信息增益最大的特征
        return best_Index

这个建树函数需要注意的两个点:

为何要传入$feat_lst$(各个特征的名字)? 因为每次划分后,特征会被删除掉。

注意2个步骤和3个退出条件

def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
        '''
        :param X:
        :param y:
        :param feat_lst:特征名字的列表
        :return:
        '''

        m,n = X.shape   #样本,特征数量
        # if m==0: return  # 返回1:没有样本了,退出;;会出现这种情况吗?
        if len(set(y)) == 1:return y[0]  #返回2:类别值都一样
        
        # 1.特征选择
        if n == 1:
            node = {'#': feat_lst[0]}  # 结点,存储特征的索引
            x = X[:, 0]
            for val in set(x):  # 该特征所有的取值
                node[val] = cal_mode(x[x==val]) #取众数
        else:
            best_Index = criterion(X, y, weight)
            splitVal = set(X[:,best_Index])     #该特征所有的取值
            if len(splitVal)==1 :return  cal_mode(y) #返回3:特征值都一样,返回频数最大的类别
            else:
                node = {'#':feat_lst[best_Index] }     #结点,存储特征的索引
                index = list(range(n))
                index.pop(best_Index)    # 需要划分的特征index
                feat_l=feat_lst[:]  #避免影响,前面的
                feat_l.pop(best_Index)
                # 2.划分数据集,递归调用种子树
                for val in splitVal:
                    i_sample = X[:, best_Index] == val  #子数据集
                    weight_ = weight if weight is None else weight[i_sample]
                    node[val] = self.build_(X[i_sample][:, index], y[i_sample], feat_l,criterion,weight_)
        return node
    def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
        '''
        :param X:
        :param y:
        :param feat_lst:特征名字的列表
        :return:
        '''

        m,n = X.shape   #样本,特征数量
        # if m==0: return  # 返回1:没有样本了,退出;;会出现这种情况吗?
        if len(set(y)) == 1:return y[0]  #返回2:类别值都一样
        
        # 1.特征选择
        if n == 1:
            node = {'#': feat_lst[0]}  # 结点,存储特征的索引
            x = X[:, 0]
            for val in set(x):  # 该特征所有的取值
                node[val] = cal_mode(x[x==val]) #取众数
        else:
            best_Index = criterion(X, y, weight)
            splitVal = set(X[:,best_Index])     #该特征所有的取值
            if len(splitVal)==1 :return  cal_mode(y) #返回3:特征值都一样,返回频数最大的类别
            else:
                node = {'#':feat_lst[best_Index] }     #结点,存储特征的索引
                index = list(range(n))
                index.pop(best_Index)    # 需要划分的特征index
                feat_l=feat_lst[:]  #避免影响,前面的
                feat_l.pop(best_Index)
                # 2.划分数据集,递归调用种子树
                for val in splitVal:
                    i_sample = X[:, best_Index] == val  #子数据集
                    weight_ = weight if weight is None else weight[i_sample]
                    node[val] = self.build_(X[i_sample][:, index], y[i_sample], feat_l,criterion,weight_)
        return node
  • 训练的函数入口
def fit(self, X, y,weight=None):
        # 建树
        self.weight = weight    #保存样本权重
        if self.splitter == 'best':
            if self.criterion == 'id3':
                # 这里用索引来代替特征的名字  list(range(X.shape[1])):索引
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
   def fit(self, X, y,weight=None):
        # 建树
        self.weight = weight    #保存样本权重
        if self.splitter == 'best':
            if self.criterion == 'id3':
                # 这里用索引来代替特征的名字  list(range(X.shape[1])):索引
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
  • 预测函数
# 分不同数据类型进行调用;二维数组或者一个向量(样本)
    def predict(self, X):
        if len(X.shape) > 1:  # 二维数组
            rst = np.zeros(X.shape[0])#.astype(objecT),可以存放字符串
            for i,x in enumerate(X):
                rst[i] = self.predict_(x)
        elif len(X) == 0:
            rst = np.inf
        else:
            rst = self.predict_(X)
        return rst
    # 真正开始预测
    def predict_(self,x):
        tree = self.tree
        while True:
            if isinstance(tree,dict):
                key = tree['#'] #树的名字
            else:
                return tree
            try:
                tree = tree[x[key]] #根据取值进入下一级
            except:
                return np.inf
    # 分不同数据类型进行调用;二维数组或者一个向量(样本)
    def predict(self, X):
        if len(X.shape) > 1:  # 二维数组
            rst = np.zeros(X.shape[0])#.astype(objecT),可以存放字符串
            for i,x in enumerate(X):
                rst[i] = self.predict_(x)
        elif len(X) == 0:
            rst = np.inf
        else:
            rst = self.predict_(X)
        return rst
    # 真正开始预测
    def predict_(self,x):
        tree = self.tree
        while True:
            if isinstance(tree,dict):
                key = tree['#'] #树的名字
            else:
                return tree
            try:
                tree = tree[x[key]] #根据取值进入下一级
            except:
                return np.inf

ID3决策树使用选择信息增益最大的特征进行划分。稍微将特征选择的标准改变,可得C4.5决策树。在信息增益高于平均水平的特征中选择信息增益率最大的。同样地,将指标改成基尼指数,也可以得到...决策树

二、测试

2.1 可跑性测试

一般而已,当你花费九牛二虎之力终于把一颗树的代码撸完之后,都会遭到跑不动沉痛打击。所以,我们先拿简单的数据集来测试。

def valid():
    '''树能不能跑'''
    dataSet = np.array([[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']])
    X = dataSet[:,:-1]
    y = dataSet[:, -1]
    m = ClassifyTree_()
    m.fit(X, y) #训练
    print(m.predict(np.array(['1','1'])))   #预测
    return m.tree

if __name__ == '__main__':
    a = valid()
    print(a)
def valid():
    '''树能不能跑'''
    dataSet = np.array([[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']])
    X = dataSet[:,:-1]
    y = dataSet[:, -1]
    m = ClassifyTree_()
    m.fit(X, y) #训练
    print(m.predict(np.array(['1','1'])))   #预测
    return m.tree

if __name__ == '__main__':
    a = valid()
    print(a)

结果如下

Python多叉树打印路径 python 多叉树生成_特征选择

三、完整代码

下面可以通过传入不同参数选择不同的树。

import numpy as np
from utils import cal_mode,Gini_index,Ent,Gain,Gain_Radio

#多叉分类树
class ClassifyTree_:
    def __init__(self,criterion="gini",splitter='best',weight=None):
        self.criterion = criterion
        self.weight = weight
        self.splitter = splitter


    def id3(self,X,y,weight):
        best_Index = -1
        best_gain = -np.inf
        ent = Ent(y,self.weight)
        for i in range(X.shape[1]):
            gain = Gain(X[:,i],y,ent,weight)
            if gain > best_gain:
                best_gain = gain
                best_Index = i  #信息增益最大的特征
        return best_Index
    def c45(self,X,y,weight):    #这里需要传入特征列表,因为X改变了
        '''建树'''
        # 特征选择
        n = X.shape[1]
        gain_arr = np.zeros(n)  # 增益
        ent = Ent(y,self.weight)
        for i in range(n):  # 特征数量
            gain_arr[i] = Gain(X[:, i], y, ent,weight)
        m_gain = np.mean(gain_arr)  # 平均增益
        best_Index = -1
        best_gain_radio = -np.inf
        for i in range(n):  # 对每个特征
            if gain_arr[i] > m_gain:
                gain_radio = Gain_Radio(X[:, i], y, ent,weight)
                if gain_radio > best_gain_radio:
                    best_gain_radio = gain_radio
                    best_Index = i
        return best_Index
    def gini(self,X,y,weight):
        '''建树'''
        # 特征选择
        best_Index = -1
        best_gini_index = np.inf
        for i in range(X.shape[1]):
            gini_index = Gini_index(X[:, i], y,weight)
            if gini_index < best_gini_index:
                best_gini_index = gini_index
                best_Index = i  # 基尼指数最小的特征
        return best_Index
    def rand_(self,X,y,weight):
        return np.random.choice(X.shape[1])

    def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
        '''
        :param X:
        :param y:
        :param feat_lst:特征名字的列表
        :return:
        '''
        # 特征选择
        m,n = X.shape   #样本,特征数量
        # if m==0: return  # 没有样本了,退出;;;会出现这种情况吗
        if len(set(y)) == 1:return y[0]  #类别值都一样
        if n == 1:
            node = {'#': feat_lst[0]}  # 结点,存储特征的索引
            x = X[:, 0]
            for val in set(x):  # 该特征所有的取值
                node[val] = cal_mode(x[x==val]) #取众数
        else:
            best_Index = criterion(X, y, weight)
            splitVal = set(X[:,best_Index])     #该特征所有的取值
            if len(splitVal)==1 :return  cal_mode(y) #特征值都一样,返回频数最大的类别
            else:
                node = {'#':feat_lst[best_Index] }     #结点,存储特征的索引
                index = list(range(n))
                index.pop(best_Index)    # 需要划分的特征index
                feat_l=feat_lst[:]  #避免影响,前面的
                feat_l.pop(best_Index)
                for val in splitVal:
                    i_sample = X[:, best_Index] == val  #子数据集
                    weight_ = weight if weight is None else weight[i_sample]
                    node[val] = self.build_(X[i_sample][:, index], y[i_sample], feat_l,criterion,weight_)
        return node

    def fit(self, X, y,weight=None):
        # 建树
        self.weight = weight
        if self.splitter == 'best':
            if self.criterion == 'id3':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
            elif self.criterion == 'c45':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.c45, weight)
            elif self.criterion == 'gini':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.gini, weight)
            else:
                raise ('gini/c45/id3')
        else:
            self.tree = self.build_(X, y, list(range(X.shape[1])), self.rand_, weight)
        return self

    # 分不同数据类型进行调用;二维数组或者一个向量(样本)
    def predict(self, X):
        if len(X.shape) > 1:  # 二维数组
            rst = np.zeros(X.shape[0])#.astype(objecT),可以存放字符串
            for i,x in enumerate(X):
                rst[i] = self.predict_(x)
        elif len(X) == 0:
            rst = np.inf
        else:
            rst = self.predict_(X)
        return rst
    # 真正开始预测
    def predict_(self,x):
        tree = self.tree
        while True:
            if isinstance(tree,dict):
                key = tree['#'] #树的名字
            else:
                return tree
            try:
                tree = tree[x[key]] #根据取值进入下一级
            except:
                return np.inf

def valid():
    '''树能不能跑'''
    dataSet = np.array([[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']])
    X = dataSet[:,:-1]
    y = dataSet[:, -1]
    m = ClassifyTree_()
    m.fit(X, y) #训练
    print('预测结果',m.predict(np.array(['1','1'])))   #预测
    return m.tree

if __name__ == '__main__':
    a = valid()
    print('训练出来的树:',a)

import numpy as np
from utils import cal_mode,Gini_index,Ent,Gain,Gain_Radio

#多叉分类树
class ClassifyTree_:
    def __init__(self,criterion="gini",splitter='best',weight=None):
        self.criterion = criterion
        self.weight = weight
        self.splitter = splitter


    def id3(self,X,y,weight):
        best_Index = -1
        best_gain = -np.inf
        ent = Ent(y,self.weight)
        for i in range(X.shape[1]):
            gain = Gain(X[:,i],y,ent,weight)
            if gain > best_gain:
                best_gain = gain
                best_Index = i  #信息增益最大的特征
        return best_Index
    def c45(self,X,y,weight):    #这里需要传入特征列表,因为X改变了
        '''建树'''
        # 特征选择
        n = X.shape[1]
        gain_arr = np.zeros(n)  # 增益
        ent = Ent(y,self.weight)
        for i in range(n):  # 特征数量
            gain_arr[i] = Gain(X[:, i], y, ent,weight)
        m_gain = np.mean(gain_arr)  # 平均增益
        best_Index = -1
        best_gain_radio = -np.inf
        for i in range(n):  # 对每个特征
            if gain_arr[i] > m_gain:
                gain_radio = Gain_Radio(X[:, i], y, ent,weight)
                if gain_radio > best_gain_radio:
                    best_gain_radio = gain_radio
                    best_Index = i
        return best_Index
    def gini(self,X,y,weight):
        '''建树'''
        # 特征选择
        best_Index = -1
        best_gini_index = np.inf
        for i in range(X.shape[1]):
            gini_index = Gini_index(X[:, i], y,weight)
            if gini_index < best_gini_index:
                best_gini_index = gini_index
                best_Index = i  # 基尼指数最小的特征
        return best_Index
    def rand_(self,X,y,weight):
        return np.random.choice(X.shape[1])

    def build_(self,X,y,feat_lst,criterion,weight=None):    #这里需要传入特征列表,因为X改变了
        '''
        :param X:
        :param y:
        :param feat_lst:特征名字的列表
        :return:
        '''
        # 特征选择
        m,n = X.shape   #样本,特征数量
        # if m==0: return  # 没有样本了,退出;;;会出现这种情况吗
        if len(set(y)) == 1:return y[0]  #类别值都一样
        if n == 1:
            node = {'#': feat_lst[0]}  # 结点,存储特征的索引
            x = X[:, 0]
            for val in set(x):  # 该特征所有的取值
                node[val] = cal_mode(x[x==val]) #取众数
        else:
            best_Index = criterion(X, y, weight)
            splitVal = set(X[:,best_Index])     #该特征所有的取值
            if len(splitVal)==1 :return  cal_mode(y) #特征值都一样,返回频数最大的类别
            else:
                node = {'#':feat_lst[best_Index] }     #结点,存储特征的索引
                index = list(range(n))
                index.pop(best_Index)    # 需要划分的特征index
                feat_l=feat_lst[:]  #避免影响,前面的
                feat_l.pop(best_Index)
                for val in splitVal:
                    i_sample = X[:, best_Index] == val  #子数据集
                    weight_ = weight if weight is None else weight[i_sample]
                    node[val] = self.build_(X[i_sample][:, index], y[i_sample], feat_l,criterion,weight_)
        return node

    def fit(self, X, y,weight=None):
        # 建树
        self.weight = weight
        if self.splitter == 'best':
            if self.criterion == 'id3':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.id3, weight)
            elif self.criterion == 'c45':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.c45, weight)
            elif self.criterion == 'gini':
                self.tree = self.build_(X, y, list(range(X.shape[1])), self.gini, weight)
            else:
                raise ('gini/c45/id3')
        else:
            self.tree = self.build_(X, y, list(range(X.shape[1])), self.rand_, weight)
        return self

    # 分不同数据类型进行调用;二维数组或者一个向量(样本)
    def predict(self, X):
        if len(X.shape) > 1:  # 二维数组
            rst = np.zeros(X.shape[0])#.astype(objecT),可以存放字符串
            for i,x in enumerate(X):
                rst[i] = self.predict_(x)
        elif len(X) == 0:
            rst = np.inf
        else:
            rst = self.predict_(X)
        return rst
    # 真正开始预测
    def predict_(self,x):
        tree = self.tree
        while True:
            if isinstance(tree,dict):
                key = tree['#'] #树的名字
            else:
                return tree
            try:
                tree = tree[x[key]] #根据取值进入下一级
            except:
                return np.inf

def valid():
    '''树能不能跑'''
    dataSet = np.array([[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']])
    X = dataSet[:,:-1]
    y = dataSet[:, -1]
    m = ClassifyTree_()
    m.fit(X, y) #训练
    print('预测结果',m.predict(np.array(['1','1'])))   #预测
    return m.tree

if __name__ == '__main__':
    a = valid()
    print('训练出来的树:',a)