系列文章:

scikit-learn小白入门教程(一)

本篇文章,将会带你动手训练一个模型。


文章目录

  • 一、Scikit-learn简介
  • 二、加载数据集
  • 三、学习和预测


一、Scikit-learn简介

SciKit learn的简称是SKlearn,是一个python库,专门用于机器学习的模块。
以下是它的官方网站,文档等资源都可以在里面找到
http://scikit-learn.org/stable/#。

SKlearn包含的机器学习方式:
分类,回归,无监督,数据降维,数据预处理等等,包含了常见的大部分机器学习方法。

关于SKlearn的安装,网上教程很多,不在此赘述。建议使用Anaconda,可以方便的安装各种库。

二、加载数据集

说了这么多,我们就来加载一个数据集来玩玩吧。scikit-learn库中就附了一些数据集,我们可以直接利用这些数据集来实践一下。

from sklearn import datasets
iris = datasets.load_iris()
digits = datasets.load_digits()

其中,iris数据是鸢尾花卉数据集,digits是手写数字数据集。而我们这次用到的数据是digits,任务是利用给出的图像数据,来预测图像上的数字是什么数字。

sklearn安装的镜像 sklearn安装教程_人工智能


digits数据集是一个类似于字典的对象,我们分析所需要用到的数据在.data成员中,该成员是一个数组。

查看digits.data,可以看到一个数组就代表了一张图像的若干特征:

sklearn安装的镜像 sklearn安装教程_python_02


而digits.target是数据集的结果,表示也就是我们所需要预测的特征——手写数字代表的数字是几?下面来查看一下digits.target:

sklearn安装的镜像 sklearn安装教程_python_03

可以看到,结果是一个一维数组,表示第1个数字是0,第2个数字是1,第3个数字是2······

三、学习和预测

接下来,将是我们的重头戏!将介绍如何利用机器学习模型来学习和预测。

对于digits数据集,我们的任务是根据给定的手写数字图像,预测其代表的数字。

在scikit-learn中,我们常用fit(X, y)和predict(T)来作为分类评估器的实现方法(不知道有什么用?不急,我们后面会讲)。这次我们使用的分类评估器是sklearn.svm.SVC,即支持向量机。

首先,先来创建一个分类器吧!

from sklearn import svm

clf=svm.SVC(gamma=0.001,C=100.)

此处,我们实例化了一个svm.SVC对象,gamma和C是这个模型的参数。如果你经常看机器学习的项目,那么你肯定会经常看到clf这个词,这个词表示classifier(分类器)。

现在,我们就创建好了一个clf,接下来我们就可以用数据集去训练它了。这时我们就要用到前面的fit()方法了。我们需要将训练集通过fit方法传递给clf,这样就可以训练clf啦!

clf.fit(digits.data[:-1],digits.target[:-1])

注意,上面的 [:-1]表示我们只用前面的n-1个样本来训练,最后一个样本我们要保留作为测试集。

当我们使用了fit后,我们的模型就已经训练好了。注意,我们训练的其实是模型的参数,训练好后的参数如下图所示。

sklearn安装的镜像 sklearn安装教程_机器学习_04

嘿嘿,现在clf训练好以后,就可以开始预测啦!这时我们要用到另一个方法predict(),我们只需要把测试集的特征输入,就可以得到预测结果。

clf.predict(digits.data[-1:])

预测结果如下图所示:

sklearn安装的镜像 sklearn安装教程_python_05

可以看到,我们的clf预测的结果是8。我们来看看我们预测的这个图像长什么样子呢:

sklearn安装的镜像 sklearn安装教程_机器学习_06

呃…好吧,是不是看不出来?感觉像是6又像是8.毕竟,图像的分辨率很差,再加上图像识别本就是一个较难的问题,如果用这几行代码就能准确预测,那我们还要花这么多时间去研究图像识别干嘛呢?对吧。

虽然,我们预测的结果可能不是很好,但是我们总算是迈出了第一步!我们自己构建并训练了一个机器学习模型,并且用它做了一些预测。怎么样,看上去也不是很难吧!