写在前面的一些话
因为自己项目需要,以及总是听说扩散模型,所以自己去b站看了视频,尽量写的通俗易懂,致力于高效省时的帮助大家搞明白扩散模型的原理,让小白也能读懂这篇论文
注意!!文章可能涉及比较多公式,但不要害怕!!结合我的说明,看懂没问题的,一步步来!不要着急,不要跳步!如果有错误,欢迎指正!有什么问题欢迎在评论区讨论!
简述
最近经常听说扩散模型,甚至可以打败GAN。回顾GAN,我们需要同时训练生成器和判别器,可能会难以收敛以及学习到一些我们不想要的特征。而diffusion model做的事情就是用了一种更简单的方式来解释生成模型应该怎么学习和生成。diffusion model火起来是因为DALLE 2的出现(也是openai的,跟chatgpt出自一个公司),实现文字转图片,能得到非常惊艳的效果(如下图,生成一个牛油果形状的沙发),可以自行搜索一下他们的网站
整个diffusion model可以分为两部分,一个是前向扩散过程,另一个是逆扩散过程,通俗理解为:前向扩散过程不停的往图片上加服从高斯分布的噪声,加到使图片变得“面目全非”(下图从右到左),逆扩散过程就是不停的减噪声然后复原成图片(从左到右)
在原论文中,扩散过程需要进行2000次加噪声的步骤,实际操作中大约200-500次。在扩散过程中,每次往图片上加的噪声就是逆过程的标签,接下来我会分别解释前向扩散过程以及逆扩散过程
前向扩散过程 forward diffusion
前面说到,扩散过程简单来说就是不停的往图片里加噪声,把图片加的面目全非。那怎么加,加多少呢?论文中给出核心公式:
这个公式怎么来的呢?别急,我们一步步来看这个公式
首先,值得一提的是,整个扩散模型是符合马尔可夫定理的,也就是说t时刻的分布只与t-1时刻有关,所以为什么公式里只出现了而没有,,…
其次,是一个经验常量,且会随着t的增大而减小,这是实验前决定的; (包括文章后面出现的,…)都是服从标准高斯分布的噪声~N(0,I) 。由此,我们可以将这个公式理解为一部分的加上了一部分的,也就是说,等于前一时刻的分布和标准高斯分布的权重和,而这个权重由决定。因为随着t的增大会减小,所以的权重会越来越小,的权重会越来越大。因此随着t的增大,噪声占比越来越大,前一时刻的分布占比越来越小。
好了,到这里我们搞懂了其中一个核心公式。但有一个问题,假如我加噪声加了1000次,我要是想知道第一千次的分布,难道要从第一步开始一步步往后推吗,知道了我才能知道,知道我才能知道?这也太慢了吧。因此论文又给了我们另一个公式:
这又是怎么来的呢?接下来慢慢解释。
让我们先根据公式(1)写出的公式 (把(1)中的t换成t-1就行了);
再结合(2)(3)写出(不用看(1),直接把(3)中的带入到(2)):
把乘进去,括号移一下:
到这里应该没什么难点,只是简单的代入。我们仔细观察一下公式(5),发现括号内是两个高斯分布相加(记住:都是服从标准的高斯分布),我们知道两个高斯分布相加还是高斯分布,具体推导可以参考另一个博主的博客。那么括号里加出来新的高斯分布具体是什么呢?
如果我们把公式(5)中的看作一整个分布,那:
同理对于公式(5)中的:
插一句,这是因为,在一个高斯分布前面乘上一个系数相当于改变它的标准差,给一个高斯分布加上或减去某个数相当于改变它的均值
ok,那根据博客里推导的公式,两个高斯分布相加后新的高斯分布应为:
我们把(6)(7)代入(8)一下,可得到新的分布具体为 (就是把(6)(7)的标准差加起来):
为了保持一致,我们就写成:
那公式(5)就可以改写成 (只用把括号里的改成我们新推导出来的高斯分布就可以了):
如果没懂就多看几遍公式,没有难点,只是代入可能会比较绕。如果懂了就接着往下面看
我们来比较一下公式(1)和(10),我再把他们写出来:
你发现了什么?是不是发现这个公式似乎可以类推?每往前推一项,只需要在权重的根号下多乘一项就可以了?那你应该能猜到如何用来表示了,就是公式(2):
OK,到这里我们就推导出了扩散过程的一个重要的核心公式,有了公式(2),我只需要知道起始的以及现在到了第几步,也就是t,我就可以直接算出是多少。扩散过程到这里就结束了,如果没懂的就回顾几遍,接下来要开始说逆扩散过程咯。
逆扩散过程 reverse diffusion
介绍完扩散过程,我们现在来说复原过程。再把这个图拿出来看。我们现在知道怎么算,也就是扩散过程,但不知道怎么算。括号里的两个变量换了位置,这使我们很容易就联想到贝叶斯公式。最常用的贝叶斯公式如下:
等式两边同时引入这个变量:
公式(12)详细证明可以看这里,我自己推的,不想看可以略过,不会影响后面理解:
证明用到了条件概率公式:
由此可证:
公式(12)右边总共由三项组成:
我们一项一项来看:
- :这个式子表示已知和去计算,是不是觉得很眼熟?对,就是我们在一开始介绍前向扩散过程的时候给出的公式,知道前一项算后一项。也就是公式(1)所表示的: 。在这里似乎已知看起来没有存在的必要,但这是为了接下来计算后面的两项。
- :理解一下这个式子表示的意思,就是已知去求,听起来好像也很熟悉。这不就是我们之前在前向扩散过程中推导出来的公式吗?若已知和时间t,我们可以根据公式(2)直接求得:
- :同理,已知去求,只需要把公式(2)中的t替换成t-1就行了:
总结,公式(12)右边三项:
Again,在一个高斯分布前面乘上一个系数相当于改变它的标准差,给一个高斯分布加上或减去某个数相当于改变它的均值
然后我们发现了什么?公式(12)的右边三项全都是高斯分布!!! 这意味着我们又可以进行一波操作来化简等式的右边
我们再来看一眼公式(12):
发现等式右半边就是一个高斯分布乘上另一个高斯分布再除以一个高斯分布。我们知道高斯分布长这个样子:
所以可想而知,对于指数部分,两项相乘等于指数相加,相除等于指数相减。
因此等式(12)的右半边(这一步对照着公式(12)和(13)~(15)看就明白了,除的前面就是负号,乘的前面就是加号,然后把均值和标准差代入):
接下来繁琐的展开和合并同类项,可以选择跳过,只需要知道我们得到了公式(16) (为了简写:)
到这其实可以看出来是个完全平方式,也就是:
又因为我们想把他表示成一个高斯分布且:
对比着(17)(18)来看,我们发现其实A就是,是个定值,只需要知道就知道,而从B整合一下就可以得到:
注意这里的代表的是逆扩散过程中,我们根据猜出来的分布的均值和方差。
是不是觉得,唉,怎么又是这么长的式子,稳住!马上快结束了!!!
不知道大家有没有发现我们求出来的均值(20)有什么不对劲的地方。我们先快速回顾一下我们在解决什么问题。之前的前向扩散是已知去求,而我们现在想求的是逆扩散过程,也就是已知去求,最终目标是能得到也就是复原后的图片。但是式子(20)里面告诉我们,求均值需要用到,这该咋办呢?只好借助前向扩散过程中的公式(2),用估计:
公式(2):
移项可得:
把公式(21)代入(20)并整理一下:
终于,我们好像知道了如何在已知的情况下去估计的分布。不过还剩一个问题,是啥啊?这里可不一定是高斯分布了,我们只在前向扩散过程中定义它为服从高斯分布的噪声,可没说在逆扩散过程中也定义为高斯分布啊。那怎么办呢?这时候,该轮到我们的机器学习闪亮登场了!!公式无法推导的,那就上机器学习暴力求解呗!我们需要这样一个模型,输入为,输出为噪声。
那既然用到机器学习,我们就需要标签,这里的标签就是在前向扩散过程的每一次迭代中,我们从高斯分布中采样得到的噪声,这个噪声我们是可以在前向扩散过程中记录下来的,因为前向加了什么噪声我们肯定知道嘛。那在逆扩散过程中,我们只需要把前向扩散过程中对应记录下来的噪声作为我们的标签,就可以训练模型,使得模型根据预测,从而根据计算t-1时刻的分布的均值 (公式(22)),又因为方差是个定值 (公式(19)),我们就可以求得t-1时刻的分布了。
我自己画了个图,帮助理解:
所以总结来说,逆扩散过程就是使预测出来的噪声和前向扩散过程中加的噪声距离越小越好。
终于!!!把前向扩散过程和逆扩散过程都解释完了!!也恭喜大家看到了这里!!撒花!!如果觉得绕也很正常,看多几遍就好啦!!
最后附上论文里的训练和预测流程图:
在training阶段的第五步就是主要用到了公式(2),求得了t时刻的,同时值得注意的是,传入模型的还有此刻的t。因为之前提到会随着t增大而减小,于是噪声也会加的越来越大,所以把t传入相当于多提供一些信息帮助训练模型来预测噪声。在sampling过程中,step4用到了公式(19)和(22),用预测的噪声算出前一时刻的分布,一步步向前传,直到预测出。
总结
扩散模型分为前向扩散过程和逆扩散过程。前向扩散过程是从算到,逆扩散是从算到。
在前扩散过程中,主要做的事情就是迭代的往图片上加服从高斯分布的噪声,从逐渐加到,我们在这一部分逐步推导了公式,核心公式为公式(1)和(2)。公式(1)告诉我们如何从算出,而公式(2)告诉我们如何根据和时刻t算出。
在逆扩散过程中,我们先利用贝叶斯公式,将前一时刻的分布,转化为可利用前向扩散过程中的公式计算的式子,发现转换后的公式主要由三个高斯分布组成,而这三个高斯分布可以组合成一个新的高斯分布,也就是前一时刻的分布。之后经过一系列化简合并,我们将前一时刻的分布的均值和方差求了出来。不过前一时刻分布的均值里面包含了一个我们无法直接求得的,所以我们需要借助机器学习去估计这个,然后利用预测出来的求得前一时刻的分布,然后逐步迭代直到算出。核心公式为公式(19)和(22)。