扩散模型(diffusion model)

扩散过程

对初始数据分布python高斯扩散模型代码 高斯扩散模型公式推导_数据分布~q(x),不断添加高斯噪声,最终使数据分布python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_02变成各项独立的高斯分布。

  • 前向扩散过程的定义
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_03
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_04(马尔科夫链过程)
  • 通过重参数化技巧,可以推导出任意时刻的python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_05,无需做迭代
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_06
    其中python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_07;参数重整化体现为python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_08中,python高斯扩散模型代码 高斯扩散模型公式推导_参数化_09为两个正态分布叠加,可以重参数化为python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_10
  • 每个时间步所添加的噪声的标准差python高斯扩散模型代码 高斯扩散模型公式推导_参数化_11给定,且随t增大而增大
  • 每个时间步所添加的噪声的均值与python高斯扩散模型代码 高斯扩散模型公式推导_参数化_11有关,为了使python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_13稳定收敛到python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_14
  • python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_15可得
  • python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_16
  • 随着不断加噪,python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_17逐渐接近纯高斯噪声
  • python高斯扩散模型代码 高斯扩散模型公式推导_参数化_18
  • 扩散过程中的后验条件概率python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_19可以用公式表达,即给定python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_20python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_21,可计算出python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_22
    假设python高斯扩散模型代码 高斯扩散模型公式推导_参数化_11足够小时,python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_24
    由高斯分布的概率密度函数python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_25和贝叶斯可得
    python高斯扩散模型代码 高斯扩散模型公式推导_参数化_26
    由二次函数的均值和方差计算可得
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_27
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_28(DDPM作者使用python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_29,认为两者结果近似)
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_21的公式代入得(python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_31使用了重参数化)
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_32
    即在python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_21条件下,后验条件概率分布可通过python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_20python高斯扩散模型代码 高斯扩散模型公式推导_参数化_35计算得到

逆扩散过程

从高斯噪声python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_36中逐步还原出原始数据python高斯扩散模型代码 高斯扩散模型公式推导_数据分布。马尔科夫链过程。

  • python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_38
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_39

目标函数

对负对数似然python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_40使用变分下限(VLB),并进一步推导化简得到最终loss

  • python高斯扩散模型代码 高斯扩散模型公式推导_参数化_41
  • 在推导的过程中,loss转换为python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_42python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_43两个高斯分布之间的KL散度,将python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_44python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_20的公式代入将loss转化为python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_46python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_21python高斯扩散模型代码 高斯扩散模型公式推导_参数化_48的公式
  • DDPM作者采用了预测随机变量(噪声)法,并不直接预测后验分布的期望值或原始数据
  • DDPM作者将方差python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_49用给定的python高斯扩散模型代码 高斯扩散模型公式推导_参数化_11python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_51代替,训练参数只存在均值中,为了使训练更加稳定

训练过程

  1. 给出原始数据python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_52
  2. 设定python高斯扩散模型代码 高斯扩散模型公式推导_参数化_53
  3. 从标准高斯分布采样一个噪声python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_54
  4. 采用梯度下降法优化目标函数python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_55

推断过程

  1. 每个时间步通过python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_20python高斯扩散模型代码 高斯扩散模型公式推导_参数化_48计算python高斯扩散模型代码 高斯扩散模型公式推导_参数化_58
    均值python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_59,方差python高斯扩散模型代码 高斯扩散模型公式推导_参数化_60
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_61
  2. 通过重参数从python高斯扩散模型代码 高斯扩散模型公式推导_参数化_58中采样得到python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_22
  3. 通过不断迭代最终得到python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_21

代码实现

  • 定义时间步数、python高斯扩散模型代码 高斯扩散模型公式推导_参数化_11python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_66等公式计算中需要用到的常量
    python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_15
    python高斯扩散模型代码 高斯扩散模型公式推导_python高斯扩散模型代码_68
    python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_69
    DDPM论文中作者将时间步数python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_70设置为1000,python高斯扩散模型代码 高斯扩散模型公式推导_参数化_71为0.0001到0.02之间的线性插值
num_timesteps = 1000
schedule_low = 1e-4
schedule_high = 0.02
betas = torch.tensor(np.linspace(schedule_low, schedule_high, num_timesteps), dtype=torch.float32)

alphas = 1 - betas
alphas_cumprod = np.cumprod(alphas)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1 - alphas_cumprod)
reciprocal_sqrt_alphas = np.sqrt(1 / alphas)
betas_over_sqrt_one_minus_alphas_cumprod = (betas / sqrt_one_minus_alphas_cumprod)
sqrt_betas = np.sqrt(betas)
  • 前向扩散过程
    python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_15
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_73
def forward_diffusion_process(model, x0, num_timesteps, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    batch_size = x0.shape[0]
    t = torch.randint(0, num_timesteps, size=(batch_size,))
    noise = torch.randn_like(x0)
    xt = sqrt_alphas_cumprod[t] * x0 + sqrt_one_minus_alphas_cumprod[t] * noise
    estimated_noise = model(xt, t)
    loss = (noise - estimated_noise).square().mean()
    return loss
  • 逆向扩散过程
    python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_74
    python高斯扩散模型代码 高斯扩散模型公式推导_数据分布_75
    python高斯扩散模型代码 高斯扩散模型公式推导_马尔科夫链_69
def reverse_diffusion_process(model, shape, num_timesteps, reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas):
    current_x = torch.randn(shape)
    x_seq = [current_x]
    for t in reversed(range(num_timesteps)):
        current_x = sample(model, current_x, t, shape[0], reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas)
        x_seq.append(current_x)
    return x_seq

def sample(model, xt, t, batch_size, reciprocal_sqrt_alphas, betas_over_sqrt_one_minus_alphas_cumprod, sqrt_betas):
    ts = torch.full([batch_size, 1], t)
    estimated_noise = model(xt, ts)
    mean = reciprocal_sqrt_alphas[ts] * (xt - betas_over_sqrt_one_minus_alphas_cumprod[ts] * estimated_noise)
    if t > 0:
        z = torch.randn_like(xt)
    else:
        z = 0
    sample = mean + sqrt_betas[t] * z
    return sample