一 扩散模型原理记录
以下内容为对上述资料的补充理解,理解不对的地方,请多指教。
以下序号与资料中的章节序号一致。
七、目标数据分布的似然函数
扩散模型本质为生成模型,所以最本质的目标是最大化对数据分布真值的预测概率。
这里可以假设成一个分类问题,不同的类别表示不同的数据分布,其中包括与数据分布真值相近的和不相近的。模型会预测不同数据分布的概率。我们的目标是,使网络对数据分布真值对应的类别的预测概率最高。
用公式表示:,其中,为模型对数据分布真值预测的概率分布(注意模型不只是网络,在扩散模型里,网络是模型的一部分,模型还包括对网络输出结果的后处理,因此网络输出值可能多种多样)。
但是范围是,直接最大化不好计算,因此一般转化为最小化对数似然函数:。直接最小化也不好求,所以扩散模型转而最小化的上界
,这个上界
就是(需要乘)。
下面的目标就是最小化。
最终转化为(与合并到一起了),其中,和都是两个高斯分布的KL散度,结果只与两个高斯分布的均值和方差有关。中两个分布的均值和方差都是已知(在分布已知的情况下已知)且不可优化的,因此直接去除。下面计算,如下式(方差是设定的固定值,所以省略了):
其中,是高斯分布的均值,是高斯分布的均值。
是模型的预测分布,也可以写成。
对上式展开,其中的均值已经在前面计算出来了,直接代入:
上式中与上文的一样,都是加的噪声。下面的问题是,我们要最小化,网络在模型中扮演什么角色?可选择的是:
- 预测,使其逼近,即损失是他俩的差;
- 预测,使其直接逼近,损失是他俩的差;
- 预测,这样分布的均值就与的均值公式一样,即下式。这样就可以逼近,即损失是他俩的差(可以简化计算);
扩散模型的作者选择用网络来预测,这样,的计算公式如下:
再简化,如下:
到这里,网络的损失就确定了,即最小化预测的噪声
与实际添加的噪声
的差,网络输入是时刻t
和时刻t对应的xt
。
有了网络输出的噪声后,就可以通过分布的均值和方差(方差是预定义的)来采样出,训练过程和反扩散过程的伪代码如下:
反扩散过程用到了重参数化采样,上图中的就是标准差。
二 问题记录
2.1 正向扩散过程的高斯均值和方差为什么这么设计?
是为了让扩散后的数据分布接近正态分布而特意设计的。