扩散模型(diffusion model)
扩散过程
对初始数据分布~q(x),不断添加高斯噪声,最终使数据分布变成各项独立的高斯分布。
- 前向扩散过程的定义
(马尔科夫链过程) - 通过重参数化技巧,可以推导出任意时刻的,无需做迭代
其中;参数重整化体现为中,为两个正态分布叠加,可以重参数化为 - 每个时间步所添加的噪声的标准差给定,且随t增大而增大
- 每个时间步所添加的噪声的均值与有关,为了使稳定收敛到
- 由可得
- 随着不断加噪,逐渐接近纯高斯噪声
- 扩散过程中的后验条件概率可以用公式表达,即给定、,可计算出
假设足够小时,
由高斯分布的概率密度函数和贝叶斯可得
由二次函数的均值和方差计算可得
(DDPM作者使用,认为两者结果近似)
将的公式代入得(使用了重参数化)
即在条件下,后验条件概率分布可通过和计算得到
逆扩散过程
从高斯噪声中逐步还原出原始数据。马尔科夫链过程。
目标函数
对负对数似然使用变分下限(VLB),并进一步推导化简得到最终loss
- 在推导的过程中,loss转换为与两个高斯分布之间的KL散度,将与的公式代入将loss转化为、、的公式
- DDPM作者采用了预测随机变量(噪声)法,并不直接预测后验分布的期望值或原始数据
- DDPM作者将方差用给定的或代替,训练参数只存在均值中,为了使训练更加稳定
训练过程
- 给出原始数据
- 设定
- 从标准高斯分布采样一个噪声
- 采用梯度下降法优化目标函数
推断过程
- 每个时间步通过和计算
均值,方差 - 通过重参数从中采样得到
- 通过不断迭代最终得到
代码实现
- 定义时间步数、、等公式计算中需要用到的常量
DDPM论文中作者将时间步数设置为1000,为0.0001到0.02之间的线性插值
num_timesteps = 1000
schedule_low = 1e-4
schedule_high = 0.02
betas = torch.tensor(np.linspace(schedule_low, schedule_high, num_timesteps), dtype=torch.float32)
alphas = 1 - betas
alphas_cumprod = np.cumprod(alphas)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1 - alphas_cumprod)
reciprocal_sqrt_alphas = np.sqrt(1 / alphas)
betas_over_sqrt_one_minus_alphas_cumprod = (betas / sqrt_one_minus_alphas_cumprod)
sqrt_betas = np.sqrt(betas)
- 前向扩散过程
def forward_diffusion_process(model, x0, num_timesteps, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
batch_size = x0.shape[0]
t = torch.randint(0, num_timesteps, size=(batch_size,))
noise = torch.randn_like(x0)
xt = sqrt_alphas_cumprod[t] * x0 + sqrt_one_minus_alphas_cumprod[t] * noise
estimated_noise = model(xt, t)
loss = (noise - estimated_noise).square().mean()
return loss
- 逆向扩散过程
def reverse_diffusion_process(model, shape, num_timesteps, reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas):
current_x = torch.randn(shape)
x_seq = [current_x]
for t in reversed(range(num_timesteps)):
current_x = sample(model, current_x, t, shape[0], reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas)
x_seq.append(current_x)
return x_seq
def sample(model, xt, t, batch_size, reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas):
ts = torch.full([batch_size, 1], t)
estimated_noise = model(xt, ts)
mean = reciprocal_sqrt_alphas[ts] * (xt - betas_over_sqrt_one_minus_alphas_cumprod[ts] * estimated_noise)
if t > 0:
z = torch.randn_like(xt)
else:
z = 0
sample = mean + sqrt_betas[t] * z
return sample