之前对决策树的算法原理做了总结,今天就从实践的角度来介绍决策树算法,主要是讲解使用scikit-learn来跑决策树算法,结果的可视化以及一些参数调参的关键点。

这里直接使用实例进行简单的决策树分类讲解,至于sklearn的DT类详细讲解可以参看官方文档

scikit-learn决策树算法类库内部实现是使用了调优过的CART树算法,既可以做分类,又可以做回归。分类决策树的类对应的是DecisionTreeClassifier,而回归决策树的类对应的是DecisionTreeRegressor。两者的参数定义几乎完全相同,但是意义不全相同。 

使用决策树对鸢尾花数据集(iris)进行分类
  • 分析iris数据集

下面将结合Scikit-learn官网的逻辑回归模型分析鸢尾花示例,给大家进行详细讲解及拓展。由于该数据集分类标签划分为3类(0类、1类、2类),很好的适用于逻辑回归模型。

1. 鸢尾花数据集

在Sklearn机器学习包中,集成了各种各样的数据集,包括前面的糖尿病数据集,这里引入的是鸢尾花卉(Iris)数据集,它是很常用的一个数据集。鸢尾花有三个亚属,分别是山鸢尾(Iris-setosa)、变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。

该数据集一共包含4个特征变量,1个类别变量。共有150个样本,iris是鸢尾植物,这里存储了其萼片和花瓣的长宽,共4个属性,鸢尾植物分三类。如表17.2所示:

scikit-learn决策树算法使用_分类

iris里有两个属性iris.data,iris.target。data是一个矩阵,每一列代表了萼片或花瓣的长宽,一共4列,每一行代表某个被测量的鸢尾植物,一共采样了150条记录。

from sklearn.datasets import load_iris   #导入数据集iris
iris = load_iris()  #载入数据集
print iris.data
#[n_samples,n_features]

输出如下所示:

[[ 5.1  3.5  1.4  0.2]
 [ 4.9  3.   1.4  0.2]
 [ 4.7  3.2  1.3  0.2]
 [ 4.6  3.1  1.5  0.2]
 ....
 [ 6.7  3.   5.2  2.3]
 [ 6.3  2.5  5.   1.9]
 [ 6.5  3.   5.2  2. ]
 [ 6.2  3.4  5.4  2.3]
 [ 5.9  3.   5.1  1.8]]

target是一个数组,存储了data中每条记录属于哪一类鸢尾植物,所以数组的长度是150,数组元素的值因为共有3类鸢尾植物,所以不同值只有3个。种类为山鸢尾、杂色鸢尾、维吉尼亚鸢尾。

scikit-learn决策树算法使用_sklearn_02

从输出结果可以看到,类标共分为三类,前面50个类标位0,中间50个类标位1,后面为2。下面给详细介绍使用决策树进行对这个数据集进行测试的代码。

2  决策树分类 

 (1)

#导入模块
import numpy as np
import matplotlib.pyplot as plt


from sklearn import tree
from sklearn.datasets import load_iris

 (2)

#载入iris数据集
iris = load_iris()

#选用第一个和第三个特征作为X
X = iris.data[:,[0,2]]

#选用target作为label
y = iris.target

#设定最大深度为4 的分类决策树
clf = tree.DecisionTreeClassifier(max_depth=4)

#拟合数据
clf = clf.fit(X,y)

(3) 

#提取特征的min和max
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1

#一维数组np.meshgrid生成网格点坐标矩阵xx和yy
#第一列花萼长度数据按h取等分作为行,并复制多行得到xx网格矩阵
#再把第二列花萼宽度数据按h取等分,作为列,并复制多列得到yy网格矩阵
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
                     np.arange(y_min, y_max, 0.1))

#调用ravel()函数将xx和yy的两个矩阵转变成一维数组
#调用np.c_[]函数组合成一个二维数组进行预测
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

#调用reshape()函数修改形状,将其Z转换为两个特征(长度和宽度)
Z = Z.reshape(xx.shape)

#plt.contourf绘制等高线
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()

最后输出结果:

scikit-learn决策树算法使用_鸢尾花数据集_03

这里对程序的讲解不是很细致,如果想了解更多可以参考分析鸢尾花数据集 

3  决策树可视化

决策树可视化化可以方便我们直观的观察模型,以及发现模型中的问题。 

from IPython.display import Image  
from sklearn import tree
import pydotplus 


dot_data = tree.export_graphviz(clf, out_file=None, 
                        class_names=iris.target_names,  
                        filled=True, rounded=True,  
                        special_characters=True)  
graph = pydotplus.graph_from_dot_data(dot_data)  
Image(graph.create_png())

可以直接在notebook中看到可视化结果,如果想了解更多可视化方式参考scikit-learn决策树算法类库使用小结

scikit-learn决策树算法使用_鸢尾花数据集_04

 

 

 

 

 

参考文章:

scikit-learn决策树算法类库使用小结

分析鸢尾花数据集

sklearn decision tree 官方文档