本篇的主要内容:
- 介绍Scipy中optimize模块的leastsq函数
最近接触到了Scipy中optimize模块的一些函数,optimize模块中提供了很多数值优化算法,其中,最小二乘法可以说是最经典的数值优化技术了, 通过最小化误差的平方来寻找最符合数据的曲线。在optimize模块中,使用leastsq()函数可以很快速地使用最小二乘法对数据进行拟合。
首先来看leastsq()函数地调用格式:
leastsq(func,
x0,
args=(),
Dfun=None,
full_output=0,
col_deriv=0,
ftol=1.49012e-08,
xtol=1.49012e-08,
gtol=0.0,
maxfev=0,
epsfcn=0.0,
factor=100,
diag=None,
warning=True)
参数还是非常多的,一般来说,我们只需要前三个参数就够了他们的作用分别是:
- func:误差函数
- x0:表示函数的参数
- args()表示数据点
举个例子:
这里要进行拟合的数据点都分布在这条正弦曲线附近:
def func(x):
return 2*np.sin(2*np.pi*x)
然后定义误差函数,所谓误差就是指我们拟合的曲线的值对应真实值的差:
def residuals(p, x, y):
fun = np.poly1d(p) # poly1d()函数可以按照输入的列表p返回一个多项式函数
return y - fun(x) # 返回真实值 与我们拟合的曲线上对应的值的差
这里设计了一个poly1d()函数,关于这个函数,简单理解下就是输入一个列表,返回以这个列表中的值为参数的多项式,例如:
输入:[1,2,3]
返回:x^2 + 2x + 3
多项式的次数是从0开始记的,要注意这个地方
下面定义关于拟合的曲线的函数:
# 拟合函数
def fitting(p):
pars = np.random.rand(p+1) # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p
r = leastsq(residuals, pars, args=(X, Y)) # 三个参数:误差函数、函数参数列表、数据点
return r
注释里的内容就是要注意的地方,由于会多次调用拟合,多以写成了函数的形式,这里传入的p是一个数字,表示我们想要得到拟合曲线的次数,比如我想针对这些数据点得到一条3次的曲线,就调用p=3类似,注意这里leastsq()函数的返回值,这里的返回值保存的是拟合的曲线的信息,如果打印这里的r,就会发现返回了一个truple,其中第一维是一个列表,保存的是拟合的曲线的参数,所以要注意如何获得这些参数。
接下来定义一下我们要进行拟合的数据点,这里定义了10个:
# 要进行拟合的数据点
X = np.linspace(0, 1, 10)
Y = [np.random.normal(0, 0.1)+num for num in func(X)] # 添加噪声
# 方便绘制曲线,所以创建多一些点
x_ = np.linspace(0, 1, 100)
y_ = func(x_)
调用拟合函数,并进行绘图:
fit_pars = fitting(3)[0] # 注意返回值中的第一行才是拟合曲线的参数列表
plt.plot(x_, y_, label='real line')
plt.scatter(X, Y, label='real points')
plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line')
plt.legend()
plt.show()
p=3的时候的图像:
当然,这里我直接传入p=3,也就是建立3次的曲线对数据点进行拟合,如果传入的p=1的时候,图像如下:
如果p=2,则是:
可以看到没有变化,也就是说没办法找到一条二次曲线,使得二次误差少于上面的一次曲线了。
完整代码如下:
import numpy as np
from scipy.optimize import leastsq
import matplotlib.pyplot as plt
# 数据点分布在这条曲线附近
def func(x):
return 2*np.sin(2*np.pi*x)
# 误差函数, 计算拟合曲线与真实数据点之间的差 ,作为leastsq函数的输入
def residuals(p, x, y):
fun = np.poly1d(p) # poly1d()函数可以按照输入的列表p返回一个多项式函数
return y - fun(x)
# 拟合函数
def fitting(p):
pars = np.random.rand(p+1) # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p
r = leastsq(residuals, pars, args=(X, Y)) # 三个参数:误差函数、函数参数列表、数据点
return r
# 要进行拟合的数据点
X = np.linspace(0, 1, 10)
Y = [np.random.normal(0, 0.1)+num for num in func(X)] # 添加噪声
# 方便绘制曲线,所以创建
x_ = np.linspace(0, 1, 100)
y_ = func(x_)
# print(fitting(3)) 可以看一下返回的是什么
fit_pars = fitting(3)[0]
plt.plot(x_, y_, label='real line')
plt.scatter(X, Y, label='real points')
plt.plot(x_, np.poly1d(fit_pars)(x_), label='fitting line')
plt.legend()
plt.show()
以上~