文章目录
- 一、支持向量机的原理
- 解决的问题:
- 线性分类及其约束条件:
- 二、实战
- 2.1、线性回归
- 2.2、支持向量机SVM
- 2.3、多项式特征
一、支持向量机的原理
- Support Vector Machine。支持向量机,其含义是通过支持向量运算的分类器。其中“机”的意思是机器,可以理解为分类器。
- 那么什么是支持向量呢?在求解的过程中,会发现只根据部分数据就可以确定分类器,这些数据称为支持向量。
- 见下图,在一个二维环境中,其中点R,S,G点和其它靠近中间黑线的点可以看作为支持向量,它们可以决定分类器,也就是黑线的具体参数。
解决的问题:
- 线性分类
在训练数据中,每个数据都有n个的属性和一个二类类别标志,我们可以认为这些数据在一个n维空间里。我们的目标是找到一个n-1维的超平面(hyperplane),这个超平面可以将数据分成两部分,每部分数据都属于同一个类别。
其实这样的超平面有很多,我们要找到一个最佳的。因此,增加一个约束条件:这个超平面到每边最近数据点的距离是最大的。也成为最大间隔超平面(maximum-margin hyperplane)。这个分类器也成为最大间隔分类器(maximum-margin classifier)。
支持向量机是一个二类分类器。 - 非线性分类
SVM的一个优势是支持非线性分类。它结合使用拉格朗日乘子法和KKT条件,以及核函数可以产生非线性分类器。
线性分类及其约束条件:
SVM的解决问题的思路是找到离超平面的最近点,通过其约束条件求出最优解。
二、实战
2.1、线性回归
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,5,30).reshape(-1,1) #转化为二维的数据
y = (X - 3)**2 + 3*X - 10 #假设一个函数
#引入线性回归函数
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X,y) #学习上面的函数
#加入测试集
X_test = np.linspace(-5,5,130).reshape(-1,1)
y_ = lr.predict(X_test) #预测的结果
#数据可视化
plt.scatter(X,y) #原函数散点图图像
plt.plot(X_test,y_,c = 'r') #回归线
plt.show()
结果分析: 从上面的图中可以看出,线性的拟合并不理想,因为原函数不属于线性分布模型。
2.2、支持向量机SVM
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,10,30).reshape(-1,1) #转化为二维的数据
y = X**3 + X**2 + X + 10 #假设一个函数
#引入支持向量机
from sklearn.svm import SVR
svr = SVR(kernel='poly',degree=3) #degree度
svr.fit(X,y) #学习上面的函数
#加入测试集预测
X_test = np.linspace(-10,20,300).reshape(-1,1)
y_ = svr.predict(X_test) #预测的结果
#数据可视化
plt.scatter(X,y) #原函数散点图图像
plt.plot(X_test,y_,c = 'r') #回归线
plt.show()
结果分析: 从图中可以看出,对于曲线的预测评估,SVM的准确度比线性回归好很多。
2.3、多项式特征
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
X = np.linspace(-5,10,30).reshape(-1,1) #转化为二维的数据
y = X**3 + X**2 + X + 10 #假设一个函数
# 数据清洗
from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=3)
X3 = poly.fit_transform(X)
#引入线性回归函数
from sklearn.linear_model import LinearRegression
lr = LinearRegression()
lr.fit(X3,y) #学习模型
#加入测试集预测
X_test = np.linspace(-10,20,300).reshape(-1,1)
X_test3 = poly.fit_transform(X_test) #需要统一数据维度
y_ = lr.predict(X_test3) #预测
plt.scatter(X,y)
plt.plot(X_test,y_,color = 'r')
plt.show()
结果分析: 如上图所示,它的的数据拟合度几乎完美了!