写在前面

大家最近应该一直都有刷到ChatGPT的相关文章。小喵之前也有做过相关分享,后续也会出文章来介绍ChatGPT背后的算法——RLHF

考虑到RLHF算法的第三步~通过强化学习微调语言模型的目标损失函数中有一项是KL散度,所以今天就先给大家分享一篇与KL散度相关的文章

0. KL散度概述

KL散度(Kullback-Leibler Divergence,KL Divergence)是一种量化两种概率分布P和Q之间差异的方式,又叫相对熵

在概率学和统计学上,我们经常会使用一种更简单的、近似的分布来替代观察数据太复杂的分布。KL散度能帮助我们度量使用一个分布来近似另一个分布时所损失的信息量

1. 外形蠕虫牙齿数量分布

KL散度定义见下文8.1节。另外在下文8.5节中解释了为什么在深度学习中,训练模型时使用的是交叉熵损失(Cross Entropy)而非KL散度。

我们从下面这个问题出发思考KL散度:

假设我们是一群太空科学家,经过遥远的旅行,来到了一颗新发现的星球。在这个星球上,生存着一种长有牙齿的蠕虫,引起了我们的研究兴趣。我们发现这种蠕虫生有10颗牙齿,但是因为不注意口腔卫生,又喜欢嚼东西,许多蠕虫会掉牙。收集大量样本之后,我们得到关于蠕虫牙齿数量的经验分布,如下图所示



python KL散度函数 kl散度求导_概率分布

牙齿数量分布

python KL散度函数 kl散度求导_python KL散度函数_02

会掉牙的外星蠕虫

这些数据很有价值,但是也有点问题。我们距离地球🌍太远了,把这些概率分布数据发送回地球过于昂贵。还好我们是一群聪明的科学家,用一个只有一两个参数的简单模型来近似原始数据会减小数据传送量。最简单的近似模型是 均分布,因为蠕虫牙齿不会超过10颗,所以有11个可能值,那蠕虫的牙齿数量概率都为 1/11。分布图如下:



python KL散度函数 kl散度求导_概率分布_03

uniform distribution/均分布

显然我们的原始数据并非均分布的,但也不是我们已知的分布,至少不是常见的分布。作为备选,我们想到的另一种简单模型是二项式分布(binomial distribution)。蠕虫嘴里面共有 个牙槽,每个牙槽出现牙齿与否为独立事件,且概率均为 。则蠕虫牙齿数量即为期望值 ,真实期望值即为观察数据的平均值,比如说5.7,则 ,得到如下图所示的二项式分布:



python KL散度函数 kl散度求导_原始数据_04

binomial/二项式分布

对比一下原始数据,可以看出均分布和二项分布都不能完全描述原始分布。



python KL散度函数 kl散度求导_原始数据_05

分布对比

可是,我们不禁要问,哪一种分布更加接近原始分布呢?

已经有许多度量误差的方式存在,但是我们所要考虑的是减小发送的信息量。上面讨论的均分布和二项式分布都把问题规约到只需要两个参数,牙齿数量和概率值(均分布只需要牙齿数量即可)。那么哪个分布保留了更多的原始数据分布的信息呢?这个时候就需要KL散度登场了。

2. 数据的熵

KL散度源于信息论。信息论主要研究如何量化数据中的信息。最重要的信息度量单位是(Entropy),一般用 表示。分布的熵的公式如下:



python KL散度函数 kl散度求导_数据_06

Entropy/熵

上面对数没有确定底数,可以是 、 或 ,等等。如果我们使用以 为底的对数计算 值的话,可以把这个值看作是编码信息所需要的最少二进制位个数bits

上面空间蠕虫的例子中,信息指的是根据观察所得的经验分布给出的蠕虫牙齿数量。计算可以得到原始数据概率分布的熵值为3.12 bits。这个值只是告诉我们编码蠕虫牙齿数量概率的信息需要的二进制位bit的位数。

可是熵值并没有给出压缩数据到最小熵值的方法,即如何编码数据才能达到最优(存储空间最优)。优化信息编码是一个非常有意思的主题,但并不是理解KL散度所必须的。熵的主要作用是告诉我们最优编码信息方案的理论下界(存储空间),以及度量数据的信息量的一种方式。理解了熵,我们就知道有多少信息蕴含在数据之中,现在我们就可以计算当我们用一个带参数的概率分布来近似替代原始数据分布的时候,到底损失了多少信息。请继续看下节内容。

3. K-L散度度量信息损失

只需要稍加修改熵H的计算公式就能得到KL散度的计算公式。设 为观察得到的概率分布, 为另一分布来近似 ,则 、 的KL散度为:



python KL散度函数 kl散度求导_python KL散度函数_07

entropy-p-q

显然,根据上面的公式,KL散度其实是数据的原始分布 和近似分布 之间的对数差值的期望

如果继续用 为底的对数计算,则KL散度值表示信息损失的二进制位数。下面公式以期望表达KL散度:



python KL散度函数 kl散度求导_概率分布_08

DKL1

一般,KL散度以下面的书写方式更常见:



python KL散度函数 kl散度求导_原始数据_09

DKL2

注:

OK,现在我们知道当用一个分布来近似另一个分布时如何计算信息损失量了。接下来,让我们重新回到最开始的蠕虫牙齿数量概率分布的问题。

4. 两种分布对比

首先是用均分布来近似原始分布的KL散度:



python KL散度函数 kl散度求导_数据_10

DKL-uniform

接下来计算用二项式分布近似原始分布的KL散度:



python KL散度函数 kl散度求导_概率分布_11

DKL-binomial

通过上面的计算可以看出,使用均分布近似原始分布的信息损失要比用二项式分布近似小。所以,如果要从均分布和二项式分布中选择一个的话,均分布更好些。

5. KL散度并非距离

很自然地,一些同学把KL散度看作是不同分布之间距离的度量。这是不对的,因为从KL散度的计算公式就可以看出KL散度不符合对称性(距离度量应该满足对称性)。

如果用我们上面观察的数据分布来近似二项式分布,得到如下结果:

python KL散度函数 kl散度求导_原始数据_12

即有:

也就是说,用 近似 和用 近似 ,二者所损失的信息并不是一样的

6. 使用KL散度优化模型

前面使用的二项式分布的参数是概率 ,是原始数据的均值。 的值域在 之间,我们要选择一个 值,建立二项式分布,目的是最小化近似误差,即KL散度。那么 是最优的吗?

下图是原始数据分布和二项式分布的KL散度变化随二项式分布参数 变化情况:


python KL散度函数 kl散度求导_原始数据_13

二项分布K-L值变化曲线

通过上面的曲线图可以看出,KL散度值在圆点处最小,即 。所以我们之前的二项式分布模型已经是最优的二项式模型了。注意这里限定在二项式模型范围内, 是最优的。

前面只考虑了均分布模型和二项式分布模型,接下来我们考虑另外一种模型来近似原始数据。首先把原始数据分成两部分:1)0-5颗牙齿的概率;2)6-10颗牙齿的概率。概率值如下:


python KL散度函数 kl散度求导_原始数据_14

ad hoc model

即一只蠕虫的牙齿数量 的概率为 ; 的概率为 ,; 。

Aha,我们自己建立了一个新的(奇怪的)模型来近似原始的分布,模型只有一个参数 ,像前面那样优化二项式分布的时候所做的一样,让我们画出KL散度值随 变化的情况:


python KL散度函数 kl散度求导_数据_15

finding an optimal parameter value for our ad hoc model

当 时,KL值取最小值 。似曾相识吗?对,这个值和使用均分布的K-L散度值是一样的(这并不能说明什么)!下面我们继续画出这个奇怪模型的概率分布图,看起来确实和均分布的概率分布图相似:


python KL散度函数 kl散度求导_概率分布_16

ad hoc model distribution

我们自己都说了,这是个奇怪的模型,在KL值相同的情况下,更倾向于使用更常见的、更简单的均分布模型。

回头看,我们在这一小节中使用KL散度作为目标方程,分别找到了二项式分布模型的参数 和上面这个随手建立的模型的参数 。

是的,这就是本节的重点:使用KL散度作为目标方程来优化模型。当然,本节中的模型都只有一个参数,也可以拓展到有更多参数的高维模型中。

7. 变分自编码器VAEs和变分贝叶斯法

如果你熟悉神经网络,你可能已经猜到我们接下来要学习的内容。除去神经网络结构的细节信息不谈,整个神经网络模型其实是在构造一个参数数量巨大的函数(百万级,甚至更多),不妨记为 ,通过设定目标函数,可以训练神经网络逼近非常复杂的真实函数 。训练的关键是要设定目标函数,反馈给神经网络当前的表现如何。训练过程就是不断减小目标函数值的过程。

我们已经知道KL散度用来度量在逼近一个分布时的信息损失量。KL散度能够赋予神经网络近似表达非常复杂数据分布的能力。变分自编码器(Variational Autoencoders,VAEs)是一种能够学习最佳近似数据集中信息的常用方法,Tutorial on Variational Autoencoders 2016[1]是一篇关于VAEs的非常不错的教程,里面讲述了如何构建VAE的细节。What are Variational Autoencoders\? A simple explanation[2]简单介绍了VAEs,Building Autoencoders in Keras[3]介绍了如何利用Keras库实现几种自编码器。

变分贝叶斯方法(Variational Bayesian Methods)是一种更常见的方法。Monte Carlo Simulations in R[4]介绍了强大的蒙特卡洛模拟方法能够解决很多概率问题。蒙特卡洛模拟能够帮助解决许多贝叶斯推理问题中的棘手积分问题,尽管计算开销很大。包括VAE在内的变分贝叶斯方法,都能用KL散度生成优化的近似分布,这种方法对棘手积分问题能进行更高效的推理。更多变分推理(Variational Inference)的知识可以访问Edward library for python[5]

8. 附录

8.1  KL散度的定义


python KL散度函数 kl散度求导_原始数据_17

KL 散度的定义

8.2 计算KL的注意事项


python KL散度函数 kl散度求导_概率分布_18

计算KL的注意事项

8.3 遇到log 0时怎么办


python KL散度函数 kl散度求导_原始数据_19

example for K-L smoothing

8.4信息熵、交叉熵、相对熵

  • 信息熵,即熵,香浓熵。编码方案完美时,最短平均编码长度。
  • 交叉熵,cross-entropy。编码方案不一定完美时(由于对概率分布的估计不一定正确),平均编码长度。是神经网络常用的损失函数。
  • 相对熵,即K-L散度,relative entropy。编码方案不一定完美时,平均编码长度相对于最小值的增加值。

更详细对比,见知乎如何通俗的解释交叉熵与相对熵\?[6]

8.5 为什么在神经网络中使用交叉熵损失函数,而不是K-L散度?

K-L散度=交叉熵-熵,即 。

在神经网络所涉及到的范围内, 不变,则 等价 。
更多讨论见《Why do we use Kullback-Leibler divergence rather than cross entropy in the t-SNE objective function?》[7]和《Why train with cross-entropy instead of KL divergence in classification?》[8]


作者:Aspirinrin