前言
通过验证一个学习器在训练集和测试集上的表现,来确定模型是否合适,参数是否合适。
如果训练集和测试集得分都很低,说明学习器不合适。
如果训练集得分高,测试集得分低,模型过拟合,训练集得分低,测试集得分高,不太可能。
示例代码
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load\_digits
from sklearn.model\_selection import validation\_curve
from sklearn.svm import SVC
# 加载数据
digits \= load\_digits()
X, y \= digits.data, digits.target
# 验证曲线
param\_range \= np.logspace(-6, -1, 5)
train\_scores, test\_scores \= validation\_curve(
SVC(), X, y, param\_name\="gamma", param\_range=param\_range,
cv\=10, scoring="accuracy", n\_jobs=1)
train\_scores\_mean \= np.mean(train\_scores, axis=1)
train\_scores\_std \= np.std(train\_scores, axis=1)
test\_scores\_mean \= np.mean(test\_scores, axis=1)
test\_scores\_std \= np.std(test\_scores, axis=1)
plt.title("SVM VC")
plt.xlabel("$\\gamma$")
plt.ylabel("Score")
plt.ylim(0.0, 1.1)
# 训练数据
plt.semilogx(param\_range, train\_scores\_mean, label\="train score", color="r")
plt.fill\_between(param\_range, train\_scores\_mean \- train\_scores\_std,
train\_scores\_mean \+ train\_scores\_std, alpha=0.2, color="r")
# 测试数据
plt.semilogx(param\_range, test\_scores\_mean, label\="test score",color="g")
plt.fill\_between(param\_range, test\_scores\_mean \- test\_scores\_std,
test\_scores\_mean \+ test\_scores\_std, alpha=0.2, color="g")
plt.legend(loc\="best")
plt.show()
输出
参数gamma的调节
很小时,训练集和测试集得分都低,欠拟合
增大时,训练集和测试集得分有个很好地值
过大时,训练集得分高,测试集得分低,过拟合。