之前对决策树的算法原理做了总结,今天就从实践的角度来介绍决策树算法,主要是讲解使用scikit-learn来跑决策树算法,结果的可视化以及一些参数调参的关键点。
这里直接使用实例进行简单的决策树分类讲解,至于sklearn的DT类详细讲解可以参看官方文档
scikit-learn决策树算法类库内部实现是使用了调优过的CART树算法,既可以做分类,又可以做回归。分类决策树的类对应的是DecisionTreeClassifier,而回归决策树的类对应的是DecisionTreeRegressor。两者的参数定义几乎完全相同,但是意义不全相同。
使用决策树对鸢尾花数据集(iris)进行分类-
分析iris数据集
下面将结合Scikit-learn官网的逻辑回归模型分析鸢尾花示例,给大家进行详细讲解及拓展。由于该数据集分类标签划分为3类(0类、1类、2类),很好的适用于逻辑回归模型。
1. 鸢尾花数据集
在Sklearn机器学习包中,集成了各种各样的数据集,包括前面的糖尿病数据集,这里引入的是鸢尾花卉(Iris)数据集,它是很常用的一个数据集。鸢尾花有三个亚属,分别是山鸢尾(Iris-setosa)、变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。
该数据集一共包含4个特征变量,1个类别变量。共有150个样本,iris是鸢尾植物,这里存储了其萼片和花瓣的长宽,共4个属性,鸢尾植物分三类。如表17.2所示:
iris里有两个属性iris.data,iris.target。data是一个矩阵,每一列代表了萼片或花瓣的长宽,一共4列,每一行代表某个被测量的鸢尾植物,一共采样了150条记录。
from sklearn.datasets import load_iris #导入数据集iris
iris = load_iris() #载入数据集
print iris.data
#[n_samples,n_features]
输出如下所示:
[[ 5.1 3.5 1.4 0.2]
[ 4.9 3. 1.4 0.2]
[ 4.7 3.2 1.3 0.2]
[ 4.6 3.1 1.5 0.2]
....
[ 6.7 3. 5.2 2.3]
[ 6.3 2.5 5. 1.9]
[ 6.5 3. 5.2 2. ]
[ 6.2 3.4 5.4 2.3]
[ 5.9 3. 5.1 1.8]]
target是一个数组,存储了data中每条记录属于哪一类鸢尾植物,所以数组的长度是150,数组元素的值因为共有3类鸢尾植物,所以不同值只有3个。种类为山鸢尾、杂色鸢尾、维吉尼亚鸢尾。
从输出结果可以看到,类标共分为三类,前面50个类标位0,中间50个类标位1,后面为2。下面给详细介绍使用决策树进行对这个数据集进行测试的代码。
2 决策树分类
(1)
#导入模块
import numpy as np
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
(2)
#载入iris数据集
iris = load_iris()
#选用第一个和第三个特征作为X
X = iris.data[:,[0,2]]
#选用target作为label
y = iris.target
#设定最大深度为4 的分类决策树
clf = tree.DecisionTreeClassifier(max_depth=4)
#拟合数据
clf = clf.fit(X,y)
(3)
#提取特征的min和max
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
#一维数组np.meshgrid生成网格点坐标矩阵xx和yy
#第一列花萼长度数据按h取等分作为行,并复制多行得到xx网格矩阵
#再把第二列花萼宽度数据按h取等分,作为列,并复制多列得到yy网格矩阵
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
#调用ravel()函数将xx和yy的两个矩阵转变成一维数组
#调用np.c_[]函数组合成一个二维数组进行预测
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
#调用reshape()函数修改形状,将其Z转换为两个特征(长度和宽度)
Z = Z.reshape(xx.shape)
#plt.contourf绘制等高线
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
plt.show()
最后输出结果:
这里对程序的讲解不是很细致,如果想了解更多可以参考分析鸢尾花数据集
3 决策树可视化
决策树可视化化可以方便我们直观的观察模型,以及发现模型中的问题。
from IPython.display import Image
from sklearn import tree
import pydotplus
dot_data = tree.export_graphviz(clf, out_file=None,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
可以直接在notebook中看到可视化结果,如果想了解更多可视化方式参考scikit-learn决策树算法类库使用小结
参考文章: