BLURRING DIFFUSION MODELS

Emiel Hoogeboom, Google Research, Brain Team, ICLR2023, Cited:11, Code: 无, Paper.

1. 前言

最近,提出了一种基于散热或模糊的生成建模的新型扩散过程,作为各向同性高斯扩散的替代方案。在这里,我们表明模糊可以等价地通过具有非各向同性噪声的高斯扩散过程来定义。在建立这种联系的过程中,我们弥合了逆散热和去噪扩散之间的差距,并阐明了这种建模选择导致的归纳偏差。最后,我们提出了一类广义扩散模型,它提供了标准高斯去噪扩散和逆散热的最佳效果,我们称之为模糊扩散模型。

2. 背景

2.1 扩散模型

java高斯扩散模型公式_人工智能

扩散模型的扩散过程可以定义为:java高斯扩散模型公式_人工智能_02,这里java高斯扩散模型公式_机器学习_03表示的是数据data,java高斯扩散模型公式_java高斯扩散模型公式_04是噪声隐变量(中间变量)。其中java高斯扩散模型公式_java高斯扩散模型公式_05单调递减,java高斯扩散模型公式_人工智能_06单调递增,数据的信息会被逐渐的被噪声摧毁。假设上述定义是马尔可夫的,那么转移分布可以写为java高斯扩散模型公式_算法_07,其中java高斯扩散模型公式_去噪_08java高斯扩散模型公式_机器学习_09

扩散模型的逆过程(去噪过程)在给定数据java高斯扩散模型公式_机器学习_03的情况下可以写为去噪分布
java高斯扩散模型公式_去噪_11
显然数据java高斯扩散模型公式_机器学习_03未知,通常利用一个神经网络来估计java高斯扩散模型公式_算法_13。此外,像DDPM中直接预测噪声java高斯扩散模型公式_算法_14来估计java高斯扩散模型公式_机器学习_03,或直接预测均值(java高斯扩散模型公式_人工智能_16)都是可以的。

最优化:模型对数似然的连续时间变分下界由以下对平方重构误差的期望给出,经验可得权重项为1效果不错。:
java高斯扩散模型公式_人工智能_17

2.2 Inverse Heat Dissipation(反向散热)

java高斯扩散模型公式_人工智能_18

通俗的讲,散热算法,如上图,可以理解为沸水的冷却,分子剧烈运动趋于平淡,或是人慢慢冷静的过程。上图可以看到的是图片中的高频成分都被平滑了,也就是模糊的过程。

逆散热模型(IHDM)使用散热来破坏信息,而不是增加越来越多的高斯噪声。他们观察到用于散热的拉普拉斯偏微分方程
java高斯扩散模型公式_去噪_19
可以用余弦变换频域的对角矩阵求解。让java高斯扩散模型公式_java高斯扩散模型公式_04表示在时间步java高斯扩散模型公式_算法_21的Laplace方程的解:
java高斯扩散模型公式_机器学习_22
其中java高斯扩散模型公式_机器学习_23表示离散余弦变换(DCT),java高斯扩散模型公式_人工智能_24表示逆DCT。java高斯扩散模型公式_去噪_25应该考虑在空间维度上向量化,以允许矩阵乘法。对角线矩阵java高斯扩散模型公式_java高斯扩散模型公式_26是频率java高斯扩散模型公式_人工智能_27和时间java高斯扩散模型公式_算法_21的加权矩阵的指数,因此java高斯扩散模型公式_去噪_29。这里是个结论,直接用就可以。因此可以定义扩散过程的边缘分布:
java高斯扩散模型公式_java高斯扩散模型公式_30
中间扩散状态java高斯扩散模型公式_java高斯扩散模型公式_04是通过对data逐步添加模糊和固定的噪声而构建的。对于生成过程,用一个可学习的生成模型近似地反转的散热过程:
java高斯扩散模型公式_人工智能_32
java高斯扩散模型公式_机器学习_33的均值是通过神经网络java高斯扩散模型公式_java高斯扩散模型公式_34直接学习的,并且具有固定的标量方差。类似于DDPMs,IHDM模型是通过采样随机时间步长java高斯扩散模型公式_算法_21的前向过程java高斯扩散模型公式_java高斯扩散模型公式_36,然后最小化模型之间的平方重建误差java高斯扩散模型公式_java高斯扩散模型公式_37和真实目标java高斯扩散模型公式_算法_38,产生以下训练损失:
java高斯扩散模型公式_人工智能_39
有些问题仍然存在: (1)散热过程是否是马尔可夫的,如果是,什么是java高斯扩散模型公式_算法_40?(2)真正的逆加热过程是否也是各向同性的,如生成过程?(3)最后,除了预测中前一个时间步长的均值之外,是否还有其他的方法呢?在下面的章节中,我们会发现: (1)是的,这个过程可以是马尔可夫的。(2)不,生成过程不是各向同性的,虽然它在频域上是对角性的。(3)是的,像散热这样的过程可以被参数化,类似于标准扩散模型中的环境参数化。

2.3 散热形式为高斯扩散

java高斯扩散模型公式_人工智能_41

我们将散热过程重新解释为一种高斯扩散的形式,类似于DDPM所使用的形式;Score-based Models等。本文将两个向量之间的乘法和除法定义为元素化的。我们从边际分布定义开始:
java高斯扩散模型公式_java高斯扩散模型公式_30
在本节中,我们让java高斯扩散模型公式_机器学习_23表示正交的DCT,这是DCT的一个特定的归一化设置。把变量java高斯扩散模型公式_机器学习_44写成这种形式,我们可以写出频率空间的扩散过程:
java高斯扩散模型公式_人工智能_45
其中java高斯扩散模型公式_java高斯扩散模型公式_46java高斯扩散模型公式_机器学习_03的频率相应,java高斯扩散模型公式_人工智能_48是对角线。这里java高斯扩散模型公式_算法_49表示关于时间步的噪声策略:线性或余弦等。上式表明,在每个维度java高斯扩散模型公式_去噪_50上,频率java高斯扩散模型公式_去噪_51的边缘分布在其标量元素java高斯扩散模型公式_算法_52上被完全分解。同样,逆向散热模型java高斯扩散模型公式_机器学习_53也被完全分解。我们可以等价地用每个维java高斯扩散模型公式_去噪_50的标量形式来描述散热过程(及其逆过程):
java高斯扩散模型公式_机器学习_55
这个方程可以看作是2.1节中介绍的标准高斯扩散过程的一个特殊情况。让java高斯扩散模型公式_人工智能_56表示为频域空间中的标准高斯扩散过程,那么java高斯扩散模型公式_算法_57。从概率论的角度来看,这里只有java高斯扩散模型公式_人工智能_58的比率很重要,而不是个体java高斯扩散模型公式_java高斯扩散模型公式_05java高斯扩散模型公式_人工智能_06的特殊选择。这是正确的,因为所有的值都可以简单地重新缩放,而无需以一种有意义的方式改变分布。这意味着,与其进行模糊处理和添加固定噪声,散热过程可以等效地定义为一个相对标准的高斯扩散过程,尽管是在频率空间。

马尔可夫转移分布:上述分析说明了与高斯扩散的等下性,因此散热模型同样也存在马尔可夫过程:
java高斯扩散模型公式_java高斯扩散模型公式_61
这与2.1中定义的是一样的,只是其中java高斯扩散模型公式_算法_62java高斯扩散模型公式_人工智能_63

去噪过程(逆过程)
java高斯扩散模型公式_机器学习_64
其中:
java高斯扩散模型公式_机器学习_65

3. Blurring Diffusion Models (模糊扩散模型)

java高斯扩散模型公式_java高斯扩散模型公式_66

利用第2.3节的分析,我们可以将该模型在频率空间中定义为一个高斯扩散模型,具有不同的维度schedules。如何参数化模型以及参数化java高斯扩散模型公式_java高斯扩散模型公式_05java高斯扩散模型公式_人工智能_06的具体schedules是很重要的。与传统模型不同,扩散过程是在频率空间中定义的:

java高斯扩散模型公式_机器学习_69

不同的频率可能以不同的速率扩散,这是由向量java高斯扩散模型公式_java高斯扩散模型公式_70java高斯扩散模型公式_人工智能_71中的值控制的。直接估计java高斯扩散模型公式_去噪_72对于神经网络来说是困难的。类似与DDPM中,学习java高斯扩散模型公式_机器学习_03的近似并带入去噪分布,通过参数化间接进行:

java高斯扩散模型公式_去噪_74

估计的java高斯扩散模型公式_人工智能_75。尽管在频率空间中表达扩散和去噪过程很方便,但神经网络往往在标准像素空间上运行良好。正是由于这个原因,神经网络java高斯扩散模型公式_java高斯扩散模型公式_34java高斯扩散模型公式_去噪_77作为输入,并预测java高斯扩散模型公式_算法_78。通过DCT矩阵就可以获得java高斯扩散模型公式_机器学习_79。将估计的java高斯扩散模型公式_机器学习_80带入到去噪分布的均值中可以得到:

java高斯扩散模型公式_机器学习_81

在像素空间无权重的目标函数可写为:

java高斯扩散模型公式_java高斯扩散模型公式_82

java高斯扩散模型公式_人工智能_83

3.1 噪声和模糊Schedules

为了精确地指定模糊过程,需要为java高斯扩散模型公式_java高斯扩散模型公式_84定义java高斯扩散模型公式_java高斯扩散模型公式_70java高斯扩散模型公式_人工智能_71的schedules。对于java高斯扩散模型公式_人工智能_71,我们为所有频率选择相同的值,因此给标量值java高斯扩散模型公式_人工智能_71一个schedule就足够了。这些shcedules是通过结合一个典型的高斯噪声扩散schedule(由标量在java高斯扩散模型公式_java高斯扩散模型公式_05java高斯扩散模型公式_人工智能_06处指定)和一个模糊schedule(由向量java高斯扩散模型公式_人工智能_91指定)来构造的。

噪声schedule:选择Variance Preserving Cosine Schedule,既java高斯扩散模型公式_去噪_92。为了避免在java高斯扩散模型公式_算法_93java高斯扩散模型公式_算法_94时的不稳定,对数信噪比java高斯扩散模型公式_人工智能_95java高斯扩散模型公式_机器学习_96时的最大值为java高斯扩散模型公式_人工智能_97,在java高斯扩散模型公式_java高斯扩散模型公式_98时至少为java高斯扩散模型公式_算法_99

模糊schedule: 没有理由使模型的概念时间步长java高斯扩散模型公式_算法_21与耗散时间完美匹配。因此,java高斯扩散模型公式_人工智能_101被重新定义,其中java高斯扩散模型公式_算法_49java高斯扩散模型公式_算法_21单调增加。变量java高斯扩散模型公式_算法_49在噪声扩散方面与java高斯扩散模型公式_java高斯扩散模型公式_05java高斯扩散模型公式_人工智能_06具有非常相似的函数:它允许相对于模型的概念时间步长java高斯扩散模型公式_算法_21的任意耗散调度。

尺度为java高斯扩散模型公式_去噪_108的高斯模糊对应于随时间java高斯扩散模型公式_算法_109的耗散。根据经验,下面的策略work well:
java高斯扩散模型公式_机器学习_110
其中,java高斯扩散模型公式_算法_111是可调的超参数,对应于图像最大模糊的程度。因此这个模糊schedule可以定义为java高斯扩散模型公式_java高斯扩散模型公式_112。如果仅仅使用java高斯扩散模型公式_去噪_113用于java高斯扩散模型公式_java高斯扩散模型公式_05,并同样用于步骤java高斯扩散模型公式_java高斯扩散模型公式_115,那么java高斯扩散模型公式_人工智能_116可能包含非常小的高频值。因此,一个不希望出现的副作用是,小的误差可能会被去噪过程中的许多步骤放大。因此,我们稍微修改了这个过程,让:
java高斯扩散模型公式_去噪_117
这里令java高斯扩散模型公式_人工智能_118。这种模糊变换将频率抑制到一个小的值dmin,同时去噪过程放大高频的侵略性较低。至此,结合高斯噪声schedule和模糊schedule,可以得到:
java高斯扩散模型公式_java高斯扩散模型公式_119
其中java高斯扩散模型公式_人工智能_120是值都为1的向量。

3.4 伪代码

# 计算alpha_t和sigma_t,见公式(a)
def get_alpha_sigma(t):
	dt = get_dt(t)
	a, sigma_t= get_noise_scaling_cosine(t)
	alpha_t = a * dt.
return alpha_t, sigma_t
# 计算d_t,见公式(b)
def get_dt(t, min_d=0.001):
	sigma_blur = sigma_blur_max * sin(t * pi / 2)^2
	dissipation_t = sigma_blur^2 / 2
	freq = pi * linsapce(0, img_dim-1, img_dim) / img_dim
	labda = freqs[None, :, None, None]^2 + freqs[None, None, :, None]^2
	dt = (1-min_d) * exp(-labda * dissipation_t) + min_d
	return dt
# t~[0,1], 见3.1噪声Schedule部分
def get_noise_schaling_cosine(t, logsnr_min=-10, logsnr_max=10):
	limit_max = arctan(exp(-0.5 * logsnr_max))
	limit_min = arctan(exp(-0.5 * logsnr_min)) - limit_max
	logsnr = -2 * log(tan(limit_min * t + limit_max))
	# Transform logsnr to a, sigma .
	return sqrt(sigmoid(logsnr)), sqrt(sigmoid(-logsnr))
# 前向过程, 见公式(c)
def diffuse(x, t):
	x_freq=DCT(x)
	alpha, sigma = get_alpha_sigma(t)
	eps = random_normal_like(x)
	z_t = IDCT(alpha * x_freq) + sigma * eps
return z_t , eps
# 损失函数,见公式(d)
def loss(x):
	t = random_uniform(0, 1)
	z_t, eps = diffuse(x, t)
	error = (eps - neural_net(z_t, t))^2
	return mean(error)
# 采样过程,t=T,T-1,...,1/T, 见公式(24)(19)
def denoise(z_t, t, delta =1e-8):
	alpha_s, sigma_s = get_alpha_sigma(t-1 / T)
	alpha_t, sigma_t = get_alpha_sigma (t)
	# Compute helpful coefficients .
	alpha_ts = alpha_t / alpha_s
	alpha_st = 1 / alpha_ts
	sigma2_ts = ( sigma ^2 - alpha_ts ^2 * sigma_s ^2)
	# Denoising variance .
	sigma2_denoise = 1 / clip (
	1 / clip ( sigma_s ^2 , min= delta ) +
	1 / clip ( sigma_t ^2 / alpha_ts ^2 - sigma_s ^2 , min= delta ) ,
	min = delta )
	# The coefficients for u_t and u_eps .
	coeff_term1 = alpha_ts * sigma2_denoise / ( sigma2_ts + delta ) # eq.24
	coeff_term2 = alpha_st * sigma2_denoise / clip(sigma_s ^2, min= delta)
	# Get neural net prediction .
	hat_eps = neural_net(z_t, t)
	# Compute terms .
	u_t = DCT(z_t)
	term1 = IDCT(coeff_term1 * u_t)
	term2 = IDCT(coeff_term2 * (u_t - sigma_t * DCT(hat_eps)))
	mu_denoise = term1 + term2
	# Sample from the denoising distribution .
	eps = random_normal_like(mu_denoise)
return mu_denoise + IDCT (sqrt(sigma2_denoise) * eps)

4. 实验

java高斯扩散模型公式_java高斯扩散模型公式_121


CIFAR10 and LSUN Churches: 为了测量生成样本的视觉质量,在经过200万步训练后,从模型中采样的50000个样本计算FID评分。从这些分数中可以看出(表1),模糊的扩散模型能够生成比ihdm的质量更高的图像,以及文献中其他类似的方法。我们的模糊扩散模型也优于标准的去噪扩散模型,尽管在这种情况下,性能上的差异不那么明显。

4.1 在不同的噪音水平和Schedules之间的比较

java高斯扩散模型公式_java高斯扩散模型公式_122


在本节中,我们分析了在最大模糊java高斯扩散模型公式_算法_111和两种不同的噪声schedule(java高斯扩散模型公式_java高斯扩散模型公式_124java高斯扩散模型公式_算法_125)方面有不同的设置。其中java高斯扩散模型公式_机器学习_126等价于一个标准的去噪扩散模型。对于CIFAR10,性能最好的模型使用的模糊值为java高斯扩散模型公式_java高斯扩散模型公式_127,其FID为3.17,优于没有应用模糊时的FID3.60,如表3所示。对于LSUN,表现最好的模型使用了更少的模糊java高斯扩散模型公式_算法_128,性能再次相对接近于java高斯扩散模型公式_java高斯扩散模型公式_127的模型。当将java高斯扩散模型公式_java高斯扩散模型公式_124java高斯扩散模型公式_算法_125 Schedule进行比较时,对于更高的最大模糊,java高斯扩散模型公式_java高斯扩散模型公式_124表现得更好。我们的假设是,java高斯扩散模型公式_算法_125 模糊过于激进,而java高斯扩散模型公式_java高斯扩散模型公式_124的图在java高斯扩散模型公式_机器学习_96附近的扩散过程开始时逐渐增加模糊。

java高斯扩散模型公式_去噪_136


有趣的是,具有较高最大模糊的模型收敛较慢,但当训练时间足够长时,比具有较少模糊的模型表现更好。比较最大模糊1和20的设定时,只有当20的设定训练超过一定步数才会比1的设定要好。似乎更高的模糊需要更多的时间来训练,但可以学会更好地匹配数据。

结论

模糊扩散模型的一个局限性是,模糊的使用具有正则化效果:当使用模糊时,将生成模型训练到收敛需要更长的时间。例如,正则化效应通常是有益的,并可以导致样本质量的提高。