在 pytorch 中,如何快速从一个均值和协方差矩阵已知的多元正态分布中采样多个向量?

问题背景

Pytorch~多组多元正态分布_人工智能

这涉及到正态总体的无偏估计,不记得的同学请出门左拐重修概率论(狗头)

Pytorch~多组多元正态分布_线性变换_02

问题分析

样本均值和样本离差阵都可以通过 torch 的内置方法,比如按列求和快速获得,那么问题就划归为,在 pytorch 中,如何快速从一个均值和协方差矩阵已知的多元正态分布中采样多个向量?

torch 的内置多元正态分布模块 distributions

其实 pytorch 本身就提供了生成多元正态分布的模块。假设我们需要采样的正态分布为:

Pytorch~多组多元正态分布_人工智能_03

mu = torch.FloatTensor([1, 2, 0])
sigma = torch.FloatTensor([
    [2, 0, 0],
    [0, 5, 0],
    [0, 0, 1]
])

那么假设我们需要从中采样100000个样本,那么代码如下:

sampler = torch.distributions.MultivariateNormal(
    loc=mu, covariance_matrix=sigma
)
samples: torch.Tensor = sampler.sample((100000, ))
# samples.shape [100000, 3]

我们可以简单验证一下这个是否正确,只需要重新计算采样的均值和协方差在小幅度内是否和总体的相等即可:

new_mu = samples.mean(dim=0)
new_sigma = (samples - mu).T @ (samples - mu) / len(samples)

print(new_mu.round())
print(new_sigma.round())

输出:

tensor([1., 2., 0.])
tensor([[2., 0., -0.],
        [0., 5., 0.],
        [-0., 0., 1.]])

很完美,说明这个内置采样函数是没问题的。而且锦恢测量了这个采样函数的运算效率,效率非常高,10万个点从建立分布再到采样,只花了10 ms,性能非常优秀。

但是这个做法有一个缺点,那就是不够自由,因为每次实例化这个类只能生成一个正态分布的采样,而在我朋友的例子里面,我们需要同时生成 B 个正态分布的 m 组采样。使用这个方法,就不可避免地需要进行 for 循环,在 B 比较大时,类的初始化和销毁也是一笔不小的开销,那么有没有更加优雅的方法呢?

换一个角度思考

我们不妨换个角度思考这个问题, pytorch 可以基于 torch.randn 和 torch.randn_like 来生成任意维度的正态分布。

既然要生成的是一个任意尺度的多元正态分布,基于小学知识,我们知道,任意多元正态分布都是标准正态分布的线性变换,原则上,如果我们可以获得标准正态分布到目标多元正态分布的线性变换,那么这个问题就迎刃而解了。

Pytorch~多组多元正态分布_正态分布_04

基于线性变换的多元正态分布生成算法

简单总结一下步骤:

Pytorch~多组多元正态分布_正态分布_05

我们不妨通过编程的方法来验证这个算法的正确性(mu和sigma还是上面的那两个,此处就不初始化了):

b = mu
A = torch.linalg.cholesky(sigma)
X = torch.randn((100000, 3))
samples = X @ A.T + b

同理,验证一下正确性:

new_mu = samples.mean(dim=0)
new_sigma = (samples - new_mu).T @ (samples - new_mu) / len(samples)

print(new_mu.round())
print(new_sigma.round())

输出:

tensor([1., 2., 0.])
tensor([[2., -0., 0.],
        [-0., 5., 0.],
        [0., 0., 1.]])

看起来完全没问题!说明我们的方法没有问题。锦恢也简单测试了一下性能,生成10万个采样只需要 3-4 ms。

拓展到多个多元正态分布的生成

此时,有些读者就显得不耐烦了:你这个算法看起来和 torch 的内置算法看起来差不多呀,不是也只能一次性生成一个吗?但是请阁下不要着急(后面有得你急的bushi),我们的这个算法中只有矩阵算法,我们将多元正态分布的生成从一个类的实例化转化到了矩阵运算,这意味着什么?这意味着我们可以利用循环向量化的思想,将 for 循环 + torch.distributions.MultivariateNormal 的写法基于 torch 的 GEMM 实现。这里引入一个热知识:

torch.matmul 是实现矩阵乘法的函数,但是如果输入的 tensor 是三维的,比如 [B, m, k] 和 [B, k, n],那么 torch.matmul 的输出结果是会忽略第一个维度,输出 [B, m, n] 的 tensor,利用这个特性,我们就可以在不写 for 循环和额外类实例化的情况下几行实现我朋友的需求。

基于线性变换的多元正态分布生成 代码实现

为了模拟我朋友的问题,我暂时用 sklearn 内置的 iris 数据集进行演示。我们将 iris 分割为5份,组成一个 [5, 30, 4] 的 tensor:

import torch
import numpy as np
from sklearn.datasets import load_iris

# 构造一个 5 * 30 * 4 的三阶张量
X, y = load_iris(return_X_y=True)
sample_num, feat_num = X.shape
batch_size = 5
three_d_tensor = X.reshape(batch_size, sample_num // batch_size, feat_num)
three_d_tensor = torch.from_numpy(three_d_tensor)

然后计算每个 batch 的正态分布重采样,也就是返回一个 [5, 30, 4] 的 tensor。这里我们实现最核心的多个多元正态并行生成算法 make_normal_along_batch

def make_normal_along_batch(tensors: torch.FloatTensor) -> torch.FloatTensor:
    assert len(tensors.shape) == 3
    sample_n = tensors.shape[1]
    mus = torch.mean(tensors, dim=1).unsqueeze(dim=1)
    nor_tensors = tensors - mus
    cov = torch.matmul(nor_tensors.permute(0, 2, 1), nor_tensors) / sample_n
    A: torch.FloatTensor = torch.linalg.cholesky(cov)
    X = torch.randn_like(tensors)
    Y = torch.matmul(X, A.permute(0, 2, 1)) + mus
    return Y

最后测试一下结果:

normals = make_normal_along_batch(three_d_tensor)
print(normals.shape)

输出:(感兴趣的读者可以自行验证结果的数值正确性)

torch.Size([5, 30, 4])

至此,我们完成了这个任务,可以进行下一个任务了。