导读

本文解析了高斯过程进行公式推导、原理阐述、可视化以及代码实现,并介绍了高斯过程回归基本原理、超参优化、高维输入等问题。

高斯过程 Gaussian Processes 是概率论和数理统计中随机过程的一种,是多元高斯分布的扩展,被应用于机器学习、信号处理等领域。本文对高斯过程进行公式推导、原理阐述、可视化以及代码实现,介绍了以高斯过程为基础的高斯过程回归 Gaussian Process Regression 基本原理、超参优化、高维输入等问题。

目录

  • 一元高斯分布
  • 多元高斯分布
  • 无限元高斯分布?
  • 核函数(协方差函数)
  • 高斯过程可视化
  • 高斯过程回归实现
  • 超参数优化
  • 多维输入
  • 高斯过程回归的优缺点

一元高斯分布

高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习

多元高斯分布

高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习_02

无限元高斯分布?

在多元高斯分布的基础上考虑进一步扩展,假设有无限多维呢?用一个例子来展示这个扩展的过程(MLSS 2012: J. Cunningham - Gaussian Processes for Machine Learning),假设我们在周一到周四每天的 7:00 测试了 4 次心率,如下图中 4 个点,可能的高斯分布如图所示(高瘦的那条)。这是一个一元高斯分布,只有每天 7: 00 的心率这个维度。



高斯数据库 达梦数据库 mysql 高斯数据库教程_python_03

现在考虑不仅在每天的 7: 00 测心率(横轴),在 8:00 时也进行测量(纵轴),这个时候变成两个维度(二元高斯分布),如下图所示



高斯数据库 达梦数据库 mysql 高斯数据库教程_python_04

更进一步,如果我们在每天的无数个时间点都进行测量,则变成了下图的情况。注意下图中把测量时间作为横轴,则每个颜色的一条线代表一个(无限个时间点的测量)无限维的采样。当对每次对无限维进行采样得到无限多个点时,其实可以理解为我们采样得到了一个函数。



高斯数据库 达梦数据库 mysql 高斯数据库教程_高斯数据库 达梦数据库 mysql_05

高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习_06

高斯核函数的 python 实现如下

import numpy as np

def gaussian_kernel(x1, x2, l=1.0, sigma_f=1.0):
    """Easy to understand but inefficient."""
    m, n = x1.shape[0], x2.shape[0]
    dist_matrix = np.zeros((m, n), dtype=float)
    for i in range(m):
        for j in range(n):
            dist_matrix[i][j] = np.sum((x1[i] - x2[j]) ** 2)
    return sigma_f ** 2 * np.exp(- 0.5 / l ** 2 * dist_matrix)

def gaussian_kernel_vectorization(x1, x2, l=1.0, sigma_f=1.0):
    """More efficient approach."""
    dist_matrix = np.sum(x1**2, 1).reshape(-1, 1) + np.sum(x2**2, 1) - 2 * np.dot(x1, x2.T)
    return sigma_f ** 2 * np.exp(-0.5 / l ** 2 * dist_matrix)

x = np.array([700, 800, 1029]).reshape(-1, 1)
print(gaussian_kernel_vectorization(x, x, l=500, sigma=10))

输出的向量  与自身的协方差矩阵为

[[100.    98.02  80.53]
 [ 98.02 100.    90.04]
 [ 80.53  90.04 100.  ]]

高斯过程可视化

下图是高斯过程的可视化,其中蓝线是高斯过程的均值,浅蓝色区域 95% 置信区间(由协方差矩阵的对角线得到),每条虚线代表一个函数采样(这里用了 100 维模拟连续无限维)。左上角第一幅图是高斯过程的先验(这里用了零均值作为先验),后面几幅图展示了当观测到新的数据点的时候,高斯过程如何更新自身的均值函数和协方差函数。



高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_07

高斯数据库 达梦数据库 mysql 高斯数据库教程_高斯数据库 达梦数据库 mysql_08

简单高斯过程回归实现

考虑代码实现一个高斯过程回归,API 接口风格采用 sciki-learn fit-predict 风格。由于高斯过程回归是一种非参数化 (non-parametric)的模型,每次的 inference 都需要利用所有的训练数据进行计算得到结果,因此并没有一个显式的训练模型参数的过程,所以 fit 方法只需要将训练数据保存下来,实际的 inference 在 predict 方法中进行。Python 代码如下

from scipy.optimize import minimize


class GPR:

    def __init__(self, optimize=True):
        self.is_fit = False
        self.train_X, self.train_y = None, None
        self.params = {"l": 0.5, "sigma_f": 0.2}
        self.optimize = optimize

    def fit(self, X, y):
        # store train data
        self.train_X = np.asarray(X)
        self.train_y = np.asarray(y)
        self.is_fit = True

    def predict(self, X):
        if not self.is_fit:
            print("GPR Model not fit yet.")
            return

        X = np.asarray(X)
        Kff = self.kernel(self.train_X, self.train_X)  # (N, N)
        Kyy = self.kernel(X, X)  # (k, k)
        Kfy = self.kernel(self.train_X, X)  # (N, k)
        Kff_inv = np.linalg.inv(Kff + 1e-8 * np.eye(len(self.train_X)))  # (N, N)
        
        mu = Kfy.T.dot(Kff_inv).dot(self.train_y)
        cov = Kyy - Kfy.T.dot(Kff_inv).dot(Kfy)
        return mu, cov

    def kernel(self, x1, x2):
        dist_matrix = np.sum(x1**2, 1).reshape(-1, 1) + np.sum(x2**2, 1) - 2 * np.dot(x1, x2.T)
        return self.params["sigma_f"] ** 2 * np.exp(-0.5 / self.params["l"] ** 2 * dist_matrix)
def y(x, noise_sigma=0.0):
    x = np.asarray(x)
    y = np.cos(x) + np.random.normal(0, noise_sigma, size=x.shape)
    return y.tolist()

train_X = np.array([3, 1, 4, 5, 9]).reshape(-1, 1)
train_y = y(train_X, noise_sigma=1e-4)
test_X = np.arange(0, 10, 0.1).reshape(-1, 1)

gpr = GPR()
gpr.fit(train_X, train_y)
mu, cov = gpr.predict(test_X)
test_y = mu.ravel()
uncertainty = 1.96 * np.sqrt(np.diag(cov))
plt.figure()
plt.title("l=%.2f sigma_f=%.2f" % (gpr.params["l"], gpr.params["sigma_f"]))
plt.fill_between(test_X.ravel(), test_y + uncertainty, test_y - uncertainty, alpha=0.1)
plt.plot(test_X, test_y, label="predict")
plt.scatter(train_X, train_y, label="train", c="red", marker="x")
plt.legend()

结果如下图,红点是训练数据,蓝线是预测值,浅蓝色区域是 95% 置信区间。真实的函数是一个 cosine 函数,可以看到在训练数据点较为密集的地方,模型预测的不确定性较低,而在训练数据点比较稀疏的区域,模型预测不确定性较高。



高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习_09

超参数优化

高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习_10


高斯数据库 达梦数据库 mysql 高斯数据库教程_机器学习_11

高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_12

from scipy.optimize import minimize


class GPR:

    def __init__(self, optimize=True):
        self.is_fit = False
        self.train_X, self.train_y = None, None
        self.params = {"l": 0.5, "sigma_f": 0.2}
        self.optimize = optimize

    def fit(self, X, y):
        # store train data
        self.train_X = np.asarray(X)
        self.train_y = np.asarray(y)

         # hyper parameters optimization
        def negative_log_likelihood_loss(params):
            self.params["l"], self.params["sigma_f"] = params[0], params[1]
            Kyy = self.kernel(self.train_X, self.train_X) + 1e-8 * np.eye(len(self.train_X))
            return 0.5 * self.train_y.T.dot(np.linalg.inv(Kyy)).dot(self.train_y) + 0.5 * np.linalg.slogdet(Kyy)[1] + 0.5 * len(self.train_X) * np.log(2 * np.pi)

        if self.optimize:
            res = minimize(negative_log_likelihood_loss, [self.params["l"], self.params["sigma_f"]],
                   bounds=((1e-4, 1e4), (1e-4, 1e4)),
                   method='L-BFGS-B')
            self.params["l"], self.params["sigma_f"] = res.x[0], res.x[1]

        self.is_fit = True

高斯数据库 达梦数据库 mysql 高斯数据库教程_python_13



高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_14

这里用 scikit-learn 的 GaussianProcessRegressor 接口进行对比

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, RBF

# fit GPR
kernel = ConstantKernel(constant_value=0.2, constant_value_bounds=(1e-4, 1e4)) * RBF(length_scale=0.5, length_scale_bounds=(1e-4, 1e4))
gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=2)
gpr.fit(train_X, train_y)
mu, cov = gpr.predict(test_X, return_cov=True)
test_y = mu.ravel()
uncertainty = 1.96 * np.sqrt(np.diag(cov))

# plotting
plt.figure()
plt.title("l=%.1f sigma_f=%.1f" % (gpr.kernel_.k2.length_scale, gpr.kernel_.k1.constant_value))
plt.fill_between(test_X.ravel(), test_y + uncertainty, test_y - uncertainty, alpha=0.1)
plt.plot(test_X, test_y, label="predict")
plt.scatter(train_X, train_y, label="train", c="red", marker="x")
plt.legend()

高斯数据库 达梦数据库 mysql 高斯数据库教程_算法_15



高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_16

多维输入

我们上面讨论的训练数据都是一维的,高斯过程直接可以扩展于多维输入的情况,直接将输入维度增加即可。

def y_2d(x, noise_sigma=0.0):
    x = np.asarray(x)
    y = np.sin(0.5 * np.linalg.norm(x, axis=1))
    y += np.random.normal(0, noise_sigma, size=y.shape)
    return y

train_X = np.random.uniform(-4, 4, (100, 2)).tolist()
train_y = y_2d(train_X, noise_sigma=1e-4)

test_d1 = np.arange(-5, 5, 0.2)
test_d2 = np.arange(-5, 5, 0.2)
test_d1, test_d2 = np.meshgrid(test_d1, test_d2)
test_X = [[d1, d2] for d1, d2 in zip(test_d1.ravel(), test_d2.ravel())]

gpr = GPR(optimize=True)
gpr.fit(train_X, train_y)
mu, cov = gpr.predict(test_X)
z = mu.reshape(test_d1.shape)

fig = plt.figure(figsize=(7, 5))
ax = Axes3D(fig)
ax.plot_surface(test_d1, test_d2, z, cmap=cm.coolwarm, linewidth=0, alpha=0.2, antialiased=False)
ax.scatter(np.asarray(train_X)[:,0], np.asarray(train_X)[:,1], train_y, c=train_y, cmap=cm.coolwarm)
ax.contourf(test_d1, test_d2, z, zdir='z', offset=0, cmap=cm.coolwarm, alpha=0.6)
ax.set_title("l=%.2f sigma_f=%.2f" % (gpr.params["l"], gpr.params["sigma_f"]))

下面是一个二维输入数据的高斯过程回归,左图是没有经过超参优化的拟合效果,右图是经过超参优化的拟合效果。



高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_17

以上相关的代码放在 toys/GP 。

高斯过程回归的优缺点

高斯数据库 达梦数据库 mysql 高斯数据库教程_可视化_18

参考资料

1.Carl Edward Rasmussen - Gaussian Processes for Machine Learning
https://www.gaussianprocess.org/gpml/chapters/RW.pdf

2.MLSS 2012 J. Cunningham - Gaussian Processes for Machine Learning
https://www.columbia.edu/~jwp2128/Teaching/E6892/papers/mlss2012_cunningham_gaussian_processes.pdf
3.Martin Krasser's blog- Gaussian Processes
https://krasserm.github.io/2018/03/19/gaussian-processes/
4.scikit-learn GaussianProcessRegressor
https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessRegressor.html