主要对 BFN 核心部分的代码实现进行详细的解析以及介绍了工程实现上的经验性问题-分布式训练时随机种子的设置

原本计划上一篇就会 kill 掉整个系列,因为不确定作者是否会开源。于是在这忐忑之期就到北京溜达了一圈,回来后惊喜(恐)地发现作者还真的开源了!“喜”当然是因为我终究能够续上一直以来在文章中贯彻的“不无聊风格”——源码解析;至于“恐”嘛~ 就是我又得费脑和费手指了..

本文是整个系列的终结篇(CW 很认真,不开玩笑!),主要内容是对 BFN 核心部分的代码实现进行解析,主要包括(按顺序):模型输入输出并计算 loss、采样生成样本(生成模型的天职)、BFN 的核心——贝叶斯流(bayesian flow)的实现、模型训练的武功秘籍——损失函数的实现、BFN 建模的关键——输出分布的实现、神经网络(model)本身的实现、数据加载和预处理、整体训练流程 以及 一个工程上实现的问题——分布式训练时随机种子的设置。

在最后一章,CW 难免要吹吹水,于是先简单总结下 BFN 的玩法,然后将其与扩散模型进行比较,最后发自内心谈谈自己对这个方法论的理解与看法。

如果仅关注 BFN 算法本身的代码实现,那么可以只看前五章;否则,如果你连 BFN 本身是什么都不知道(请问您是怎么进来的..),那么就直接跳转到最后一章吧,或许能令你对 BFN 有个浅浅的认识(并体会到不无聊的风格);又或者,你不小心手抖点进来了,也可以看看倒数第二章,即第九章,那是个工程实现上的经验性问题,只要是在 Pytorch 的分布式框架下玩都适用;若这些情况都不是,那么 CW 懂了——你是要看完全文!辛苦了您,感恩~!

附:BFN 官方源码:https://github.com/nnaisense/bayesian-flow-networks

一、Loss 计算

作者将 loss 的计算流程封装在了 BFN 这个类里,同时,还在其中封装了采样生成的过程。所以,要注意在代码实现中,这个类并非代表神经网络(model)本身的实现,而是 BFN work 的整体逻辑:loss 计算对应训练过程、采样生成则对应推理过程。

整体流程

class BFN(nn.Module):
    def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):
        super().__init__()
        
        self.net = net
        self.bayesian_flow = bayesian_flow
        self.loss = loss

    def forward(
        self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None
    ) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]:
        """
        Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.
        t is sampled randomly if None. If t is not None, expect t.shape == data.shape.
        
        使用蒙特卡洛方法估计发送者分布和接收者分布之间的 KL 散度损失:
        -采样时间变量;
        -从贝叶斯流分布中采样得到输入分布的参数(后验更新);
        -将输入分布的参数喂给模型;
        -模型返回输出分布;
        -计算连续/离散时间 loss.
        """

        t = self.sample_t(data, n_steps) if t is None else t
        
        # sample input parameter flow
        # 从贝叶斯流分布中采样出输入分布的参数(代表已完成后验更新).
        input_params = self.bayesian_flow(data, t)
        # 在输入模型前转换为适合于模型输入的形式(如有必要的话)
        net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)
        # compute output distribution parameters
        # 注意, 这里模型输出的通常不是输出分布的参数, 而是某些变量(比如估计的噪声),
        # 它们经过后处理才最终成为输出分布的参数.
        output_params: Tensor = self.net(net_inputs, t)

        # compute KL loss in float32
        with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):
            if n_steps == 0 or n_steps is None:
                loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)
            else:
                loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)

        # loss shape is (batch_size, 1)
        return loss.mean()

loss 计算的整个流程 CW 已在上述注释中写明。在连续时间的情况下是不需要指定总时间步 n_steps 的,因此当 n_steps = 0 或未指定时就使用连续时间的损失函数进行计算;否则,就使用离散时间的损失函数。至于损失函数的实现,后文会详细解析。

Bayesian Flow Networks(BFN)合集5_建模

以上前向过程 forward() 的第一步就是采样出时间变量,下面来看看这一步的具体实现。

时间变量的采样

@staticmethod
    @torch.no_grad()
    def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:
        """采样时间变量 t, 包括连续时间和离散时间两种情况."""
        
        # 连续时间情况不需要指定总步数, 从 U(0,1) 连续型均匀分布中采样.
        if n_steps == 0 or n_steps is None:
            # (B,1)
            t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)
        # 离散时间情况则先从 U{0,n-1} 离散型均匀分布采样出时间步,然后再除总步数 n 计算出对应的时间变量值: t = \frac{i-1}{n}
        # 注意, 这是每个区间起始时刻的值.
        else:
            # (B,1)
            t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps
        # 扩展至和数据同样的维度, 不同的数据样本的时间变量不一致, 同一个样本内所有维度上所对应的时间变量则相同.
        t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)
        
        return t

这个 sample_t() 方法也是封装在 BFN 这个类里的,但从程序设计的逻辑上来看,它并不专属于某个特定的类,而是可以作为通用方法来使用的,因此用 @staticmethod 修饰器使其成为静态方法。

时间变量仅在不同数据样本之间存在差异,而同一个样本在所有维度上都应该拥有相同的时间变量值,于是采样的时间变量个数只需与数据样本的数量相等即可,这个数对应于 batch_size,也就是 data.size(0),采样完成后再将维度扩充至与数据相同。

二、采样生成

以下是采样生成样本的过程,整体可概括为:

  1. 设置先验参数 input_params
  2. 根据当前时间步计算对应的时间变量 t
  3. 将先验和时间变量输入模型令其返回输出分布的参数 output_params
  4. 从输出分布中采样,采样结果当作当前步骤的生成样本 output_sample
  5. 根据当前时间步计算出对应的精度 alpha
  6. 以输出分布的样本和精度为参数,从发送者分布中采样出观测样本 y
  7. 利用观测样本根据贝叶斯更新函数(贝叶斯定理)计算后验,从而对先验进行更新 update_input_params(...)
  8. 不断重复 2~7,待完成至规定的总步数 n_steps 后(那时 t=1)再根据 3 ~ 4 生成最终的样本
@torch.inference_mode()
    def sample(self, data_shape: tuple, n_steps: int) -> Tensor:
        device = next(self.parameters()).device
        
        # 起始时刻的先验
        input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)
        distribution_factory = self.loss.distribution_factory

        for i in range(1, n_steps):
            # t_{i-1} = \frac{i-1}{n}
            t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps
            
            # 模型接收输入分布的参数并预测,形成输出分布的参数后,再从其中采样作为预测(生成)的数据样本.
            output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
            output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()
            output_sample = output_sample.reshape(*data_shape)
            
            # 计算精度 \alpha_i
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            # 采样观测样本
            y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()
            # 后验更新
            input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)

        # 最后时刻 t=1
        t = torch.ones(*data_shape, device=device)
        output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)
        # 概率分布的众数(mode)作为样本.
        output_sample = distribution_factory.get_dist(output_params, input_params, t).mode
        output_sample = output_sample.reshape(*data_shape)
        
        return output_sample

以上需要特别注意的是,最终是使用输出分布的 mode:也就是一个概率分布的众数(概率最大处所对应的样本,即最有可能出现的结果)作为生成结果;而在前面迭代的过程中,使用的是输出分布的常规采样结果作为当前步骤生成的样本。

三、贝叶斯流的实现

贝叶斯流的目标是计算后验,从而对先验进行更新。但与基于贝叶斯定理来计算后验的单步更新不同,它能够根据原始数据样本和任意时间变量计算出对应时刻的后验,而不依赖于由起始时刻至今过程中的那些观测样本。

作者实现了一个抽象基类 BayesianFlow,其中定义了贝叶斯流会用到的一些方法(定义为抽象方法 abstractmethod),而建模不同类型数据时所对应的贝叶斯流都要继承这个基类,并且将抽象方法都真正地实现(overwrite)。

class BayesianFlow(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:
        """Returns the initial input params (for a batch) at t=0. Used during sampling.
        For discrete data, the tuple has length 1 and contains the initial class probabilities.
        For continuous data, the tuple has length 2 and contains the mean and precision.
        
        返回起始时刻的先验参数, 作为模型的输入, 方法用于采样过程的开端."""
        pass

    @abstractmethod
    def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:
        """Utility method to convert input distribution params to network inputs if needed.
        
        如果有必要的话, 将输入分布的参数转换为适合模型输入的形式.
        比如在建模离散化数据时, 输入分布的参数代表概率, 取值范围在[0,1], 于是在输入模型前会将其 scale 至[-1,1],
        从而与其他类型的数据场景兼容, 并且避免让模型永远只接收非负值."""
        pass

    @abstractmethod
    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:
        """Returns the alpha at step i of total n_steps according to the flow schedule. Used:
        a) during sampling, when i and alpha are the same for all samples in the batch.
        b) during discrete time loss computation, when i and alpha are different for samples in the batch.
        
        计算某个离散时间步所对应的精度: \alpha_i = \beta(t_i) - \beta(t_{i-1}), 用于采样过程或离散时间的损失函数. """
        pass

    @abstractmethod
    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        """Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used:
        a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.
        b) during discrete time loss computation when alpha are different for samples in the batch.
        
        返回指定精度 \alpha 下的输入分布. """
        pass

    @abstractmethod
    def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:
        """Updates the distribution parameters using Bayes' theorem in light of noisy sample y.
        Used during sampling when alpha is the same for the whole batch.
        
        根据贝叶斯定理利用观测样本 y 计算后验, 从而更新先验. """
        pass

    @abstractmethod
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:
        """Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.
        Used during training when t (and thus accuracies) are different for different samples in the batch.
        For discrete data, the returned tuple has length 1 and contains the class probabilities.
        For continuous data, the returned tuple has length 2 and contains the mean and precision.
        
        从贝叶斯流分布中采样得到后验, 代表对输入分布参数的更新. """
        pass

由上可以看到,作者将单步更新的贝叶斯更新函数也封装在了 BayesianFlow 这个类里,可能考虑到两者的目标都一致吧(都是计算后验、更新先验)。

建模连续和离散化数据的贝叶斯流被作者实现为 CtsBayesianFlow 类,因为离散化数据就是由连续数据经过离散化操作而得到的,所以两者共用一套逻辑;而建模离散数据的贝叶斯流则实现为 DiscreteBayesianFlow 类。

接下来,我们就分别深入到两者的内部去一探究竟吧~!

建模连续和离散化数据

在建模连续和离散化数据时,贝叶斯流分布为:

Bayesian Flow Networks(BFN)合集5_人工智能_02

class CtsBayesianFlow(BayesianFlow):
    """建模连续/离散化数据的贝叶斯流."""
    
    def __init__(
        self,
        min_variance: float = 1e-6,
    ):
        super().__init__()
        self.min_variance = min_variance

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:
        """返回贝叶斯流分布的采样结果, 即经过后验更新的输入分布的均值向量: \mu."""
        
        # \omega_1^{2t}
        post_var = torch.pow(self.min_variance, t)
        # \gamma(t)
        alpha_t = 1 - post_var
        # \gamma(t)(1-\gamma(t))
        mean_var = alpha_t * post_var
        
        # 贝叶斯流分布的均值: \gamma(t)x
        mean_mean = alpha_t * data
        # 贝叶斯流分布的标准差: \sqrt{\gamma(t)(1-\gamma(t))}
        mean_std_dev = mean_var.sqrt()
        
        # 标准高斯噪声
        noise = torch.randn(mean_mean.shape, device=mean_mean.device)
        # 利用重参数化技术构造贝叶斯流分布的样本
        mean = mean_mean + (mean_std_dev * noise)
        
        # We don't need to compute the variance because it is not needed by the network, so set it to None
        input_params = (mean, None)
        
        return input_params

另外,以上并非直接从贝叶斯流分布中进行采样,而是使用了重参数化技术——先从标准正态分布中采样出高斯噪声,然后再通过 scale & shift 获得目标分布的采样结果,即:

Bayesian Flow Networks(BFN)合集5_人工智能_03

这个过程对应于以下代码中的 update_input_params() 方法。

至于其它就比较琐碎且简单了,各位客官自行看下面代码即可:

def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        # 仅取输入分布的均值向量作为 BFN 的输入
        # Only the mean is used by the network
        return params[0]

    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:
        # 起始时刻的先验是标准高斯分布, 均值为0, 方差为1(协方差矩阵是对角元均为1的对角阵)
        return torch.zeros(*data_shape, device=device), 1.0

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        # 根据 \beta(t_i) - \beta(t_{i-1}) 计算, 其中 t_i = \frac{i}{n}.
        sigma_1 = math.sqrt(self.min_variance)
        return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        # 返回输入分布, 精度 \alpha 是方差的倒数.
        dist = D.Normal(x, 1.0 / alpha**0.5)
        return dist

    def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:
        """贝叶斯更新函数, 对输入分布的参数进行后验更新."""
        
        input_mean, input_precision = input_params
        # \rho_i = \rho_{i-1} + \alpha
        new_precision = input_precision + alpha
        # 根据贝叶斯定理计算: \mu_i = \frac{ \rho_{i-1} \mu_{i-1} + \alpha y }{\rho_i}
        new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision
        
        return new_mean, new_precision

建模离散数据

与连续和离散化数据的场景不同,当建模离散数据时,贝叶斯流分布则为:

Bayesian Flow Networks(BFN)合集5_数据_04

class DiscreteBayesianFlow(BayesianFlow):
    def __init__(
        self,
        n_classes: int,
        min_sqrt_beta: float = 1e-10,
        discretize: bool = False,
        epsilon: float = 1e-6,
        max_sqrt_beta: float = 1,
    ):
        super().__init__()
        
        # K
        self.n_classes = n_classes
        # 一个极小值, 用于将传入贝叶斯流分布的时间变量最大值限制至 1-epsilon.
        # 因为贝叶斯流分布是用于最终时刻前的, 所以需要 t < 1.
        self.epsilon = epsilon
        
        # 是否进行离散化操作
        self.discretize = discretize
        
        # \sqrt{\beta} 的下限
        self.min_sqrt_beta = min_sqrt_beta
        # \sqrt{\beta(1)}
        self.max_sqrt_beta = max_sqrt_beta
        
        # 均匀分布的期望熵: H = - \sum_{i=1}^K{p(x_i)ln(p(x_i))}, p(x_i)=\frac{1}{K}
        self.uniform_entropy = math.log(self.n_classes)

    @torch.no_grad()
    def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:
        """根据贝叶斯流分布完成后验更新."""
        
        if self.discretize:
            # 若要进行离散化操作, 则将数据以对应的离散化区间索引表示.
            data = float_to_idx(data, self.n_classes)
        
        # \sqrt{\beta(t)}
        sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))
        lo_beta = sqrt_beta < self.min_sqrt_beta
        sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)
        # \beta(t)
        beta = sqrt_beta.square().unsqueeze(-1)
        
        # 从精度参数为 \beta(t) 的发送者分布中采样观测样本以作为贝叶斯流分布的 logits.
        logits = self.count_sample(data, beta)
        probs = F.softmax(logits, -1)
        # 将精度太小的部分所对应的后验以均匀先验 \frac{1}{K} 代替.
        # 这是因为精度太小, 那么对应的观测样本也"不靠谱"——所包含真实数据的信息太少,
        # 将其作为 logits 就不靠谱, 即以此为根据而实现的后验更新意义不大.
        probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)
        if self.n_classes == 2:
            # 如果是二分类则只取其中一类的概率即可.
            probs = probs[..., :1]
            probs = probs.reshape_as(data)
            
        input_params = (probs,)
        
        return input_params

    def t_to_sqrt_beta(self, t):
        """计算当前时刻的 accuracy schedule: \beta(t) 的开根:
           sqrt{\beta(t)} = t \sqrt{\beta(1)}."""
        
        return t * self.max_sqrt_beta

    def count_dist(self, x, beta=None) -> D.Distribution:
        """贝叶斯流分布中的期望部分所对应的发送者分布."""

        # Ke_x - 1
        mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1
        # \sqrt{K}
        std_dev = math.sqrt(self.n_classes)
        
        if beta is not None:
            # \beta(t)(Ke_x - 1)
            mean = mean * beta
            # \sqrt{\beta(t)K}
            std_dev = std_dev * beta.sqrt()
            
        return D.Normal(mean, std_dev, validate_args=False)

    def count_sample(self, x, beta):
        """利用重参数化采样技术(rsample())采样出观测样本作为贝叶斯流分布的 logits 源(下一步将其输入 softmax 以实现后验更新)."""
        return self.count_dist(x, beta).rsample()

利用贝叶斯流分布更新先验的整个过程即以上的前向过程 forward(),其中代码对应的释义 CW 都已详细注解。在上面的代码实现中,需要注意的细节有几个:

Bayesian Flow Networks(BFN)合集5_离散化_05

关于以上最后一点,通常使用重参数化采样是因为要使得梯度流能通过要学习的参数,但是以上这部分却没有需要学习的参数,之所以还这样做可能是考虑到在高维空间中从标准高斯分布中采样会相对高效。另外,对数据进行离散操作的 float_to_idx() 方法会在后文“数据加载与预处理”那章进行解析。

与前面一节建模连续和离散化数据时一样,这个 DiscreteBayesianFlow 类还封装了许多有用的方法,比如:

  • 根据贝叶斯定理来计算后验的贝叶斯更新函数

Bayesian Flow Networks(BFN)合集5_数据_06

@torch.no_grad()
    def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:
        """初始先验: 各类别概率相等的均匀分布 U{1, K}."""
        
        # 注意返回的是元组, 这是为了与连续/离散化数据的场景保持一致性.
        return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)

    @torch.no_grad()
    def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:
        params = params[0]
        if self.n_classes == 2:
            # 作者使用的 MNIST 数据集是经过二值化处理的, 因此这部分针对 MNIST 操作,
            # 将模型输入的范围缩放至 [-1,1]
            params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text
            # 因为总共只有两个类别, 所以取其中一类所对应的概率即可.
            params = params[..., :1]
            
        return params

    def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:
        # 计算离散时间步所对应的精度: \alpha_i = \beta(1) \frac{2i-1}{n^2}
        return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)

    def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:
        e_x = F.one_hot(x.long(), self.n_classes)
        alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha
        dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)
        
        return dist

    def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:
        """贝叶斯更新函数: 利用贝叶斯定理计算后验."""
        
        new_input_params = input_params[0] * y.exp()
        new_input_params /= new_input_params.sum(-1, keepdims=True)
        
        # 注意返回的是元组
        return (new_input_params,)

四、损失函数的实现

作者将损失函数封装在了一个抽象基类 Loss 里,其中包含了三种具体的 loss 计算:连续时间下发送者分布和接收者分布的 KL loss、离散时间下发送者分布和接收者分布的 KL loss 以及 实际不参与训练的重构 loss。

无论是针对 连续、离散化 亦或是 离散数据的损失函数,都要继承这个基类,并实现以上三种 loss 计算的逻辑。

class Loss(nn.Module, ABC):
    def __init__(self):
        super().__init__()

    @abstractmethod
    def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:
        """Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).
        The input params are only used when the network is parameterized to predict the noise for continuous data.
        
        连续时间的损失函数. """
        pass

    @abstractmethod
    def discrete_time_loss(
        self, data: Tensor,
        output_params: Tensor, input_params: Tensor,
        t: Tensor, n_steps: int, n_samples: int = 20
    ) -> Tensor:
        """Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using
        n_samples for Monte Carlo estimation of the discrete loss.
        The input params are only used when the network is parameterized to predict the noise for continuous data.
        
        离散时间的损失函数, 当所需计算的 KL 散度没有解析形式时, 使用蒙特卡洛方法来近似估计. """
        pass

    @abstractmethod
    def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        """Returns the reconstruction loss, i.e. the final cost of transmitting clean data.
        The input params are only used when the network is parameterized to predict the noise for continuous data.
        
        重构损失, 不参与训练. """
        pass

连续和离散化数据的损失函数

连续和离散化数据的 loss 计算共用一套逻辑,被封装在 CtsBayesianFlowLoss 这个类中。

def sandwich(x: Tensor):
    return x.reshape(x.size(0), -1, x.size(-1))


class CtsBayesianFlowLoss(Loss):
    """建模连续/离散化数据场景时所用的损失函数, 包括:
    -离散时间损失函数;
    -连续时间损失函数;
    -重构损失"""
    
    def __init__(
        self,
        bayesian_flow: CtsBayesianFlow,
        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
        min_loss_variance: float = -1,
        noise_pred: bool = True,
    ):
        super().__init__()
        
        self.bayesian_flow = bayesian_flow
        # 返回输出分布的工厂对象
        self.distribution_factory = distribution_factory
        # \sigma_1^{2} 的下限, 以防用作分母时溢出.
        self.min_loss_variance = min_loss_variance
        # -ln(\sigma_1)
        self.C = -0.5 * math.log(bayesian_flow.min_variance)
        
        # 是否预测噪声(亦或是直接预测数据)
        self.noise_pred = noise_pred
        if self.noise_pred:
            self.distribution_factory.log_dev = False
            # 在预测噪声的情况下, 将预测的噪声(或噪声分布相关的参数)转换为对应数据分布(输出分布)的参数.
            self.distribution_factory = PredDistToDataDistFactory(
                self.distribution_factory, self.bayesian_flow.min_variance
            )

CW 在本系列的第二三篇文章中解析过,在建模连续和离散化数据时,模型预测的分别是高斯噪声 和 高斯噪声分布相关的参数:均值  和 对数标准差 ,因此需要对模型的输出进行一些后处理,以便将其转换为目标数据分布(输出分布)所对应的参数。

  • 连续时间的损失函数

Bayesian Flow Networks(BFN)合集5_离散化_07

def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        # 模型输出
        # reshape 成3维:(B, -1, D)
        output_params = sandwich(output_params)
        
        t = t.flatten(start_dim=1).float()
        flat_target = data.flatten(start_dim=1)
        
        # \sigma_1^{2t}
        posterior_var = torch.pow(self.bayesian_flow.min_variance, t)
        if self.min_loss_variance > 0:
            # 做最小值截断, 以防其作分母时防止溢出
            posterior_var = posterior_var.clamp(min=self.min_loss_variance)
        
        # 输出分布
        pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        # 输出分布的均值 E[P(\theta, t)]
        pred_mean = pred_dist.mean
        
        mse_loss = (pred_mean - flat_target).square()
        # 连续时间的损失函数计算公式: -ln(\sigma_1) \sigma_1{-2t} || x - E[P(\theta, t)] ||^2
        loss = self.C * mse_loss / posterior_var
        
        return loss
  • 离散时间的损失函数

在离散时间的条件下,连续(continuous)数据的 KL loss 依然是 mse,只不过 scale 系数相比于连续时间的情况有些不同;而离散化(discretized)数据的就稍微复杂些了,由于没有解析形式,因此需要使用蒙特卡洛方法从发送者分布中进行采样去近似估计 KL 散度:

Bayesian Flow Networks(BFN)合集5_人工智能_08

def discrete_time_loss(
        self, data: Tensor,
        output_params: Tensor, input_params: Tensor,
        t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        # (B,-1,D)
        output_params = sandwich(output_params)
        t = t.flatten(start_dim=1).float()
        
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)

        # 离散化数据的场景
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            t = t.flatten(start_dim=1)
            i = t * n_steps + 1  # since t = (i - 1) / n
            
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            
            flat_target = data.flatten(start_dim=1)
            # 发送者分布
            sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
            # 因为使用蒙特卡洛方法来估计发送者分布与接收者分布之间的 KL 散度,所以要从发送者分布中采样观测样本 y,
            # 采样的样本数默认为10.
            y = sender_dist.sample(torch.Size([n_samples]))
            
            # 模型输出的分配到各离散化区间的概率值. 
            #(B,D,K)
            receiver_mix_wts = sandwich(output_dist.probs)
            # 输出分布是类别分布, 在每个离散化区间都分配一定概率.
             = D.Categorical(probs=receiver_mix_wts, validate_args=False)
            # 以各离散化区间的中心为均值构造多个一维高斯分布,其中每个都与发送者分布的形式一致(噪声强度相等, 即方差一致).\
            receiver_components = D.Normal(
                output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False
            )
            # 接收者分布, 在数据的每个维度上都是混合高斯分布.
            receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)
            
            # (B,1)
            loss = (
                (sender_dist.log_prob(y) - receiver_dist.log_prob(y))  # 发送者分布和接收者分布的概率密度对数差
                .mean(0)  # 在蒙特卡洛采样的样本数上做平均
                .flatten(start_dim=1)
                .mean(1, keepdims=True)
            )
        # 连续数据的场景
        else:  # output distribution is normal
            pred_mean = output_dist.mean
            flat_target = data.flatten(start_dim=1)
            mse_loss = (pred_mean - flat_target).square()
            i = t * n_steps + 1
            alpha = self.bayesian_flow.get_alpha(i, n_steps)
            loss = alpha * mse_loss / 2
            
        return n_steps * loss

代码咋一看有些复杂,但认真看后实际上你会发现还好,和论文中的公式是能够完美对上的,至于公式推导的细节,请参考本系列第二三篇文章。

在实现时需要特别注意的是在建模离散化数据时,输出分布在每一维上都是混合高斯分布,以上使用了 torch.distributions.MixtureSameFamily 来实现,其中每个子高斯分布 receiver_components 的权重就对应输出分布在每个离散化区间上所分配的概率,注意要用这批概率去实例化一个类别分布(torch.distributions.Categorical)对象 receiver_mix_dist 并传到 MixtureSameFamily 中,并且子高斯分布 receiver_components 的个数要和类别分布 receiver_mix_dist 的类别数一致。

  • 重构损失

Bayesian Flow Networks(BFN)合集5_数据_09

def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        output_params = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        
        # 重构损失只发生在最后时刻,于是 t=1.
        t = torch.ones_like(data).flatten(start_dim=1).float()
        output_dist = self.distribution_factory.get_dist(output_params, input_params, t)
        
        if hasattr(output_dist, "probs"):  # output distribution is discretized normal
            reconstruction_loss = -output_dist.log_prob(flat_data)
        else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)
            if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10
                noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 16
            else:
                noise_dev = math.sqrt(self.bayesian_flow.min_variance)
                num_bins = 256
                
            mean = output_dist.mean.flatten(start_dim=1)
            final_dist = D.Normal(mean, noise_dev)
            # 离散化的正态分布
            final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)
            reconstruction_loss = -final_dist.log_prob(flat_data)
            
        return reconstruction_loss

离散数据的损失函数

针对离散数据的损失函数被封装在 DiscreteBayesianFlowLoss 这个类里,它也是 Loss 的子类,同样需要实现以上提到的三项 loss 的计算逻辑。

class DiscreteBayesianFlowLoss(Loss):
    def __init__(
        self,
        bayesian_flow: DiscreteBayesianFlow,
        distribution_factory: DiscreteDistributionFactory,
    ):
        super().__init__()
        
        self.bayesian_flow = bayesian_flow
        self.distribution_factory = distribution_factory
        # 离散数据的输出分布建模为类别分布,这个变量就代表类别数量.
        self.K = self.bayesian_flow.n_classes
  • 连续时间的损失函数

建模离散数据时,连续时间的 KL loss 为:

Bayesian Flow Networks(BFN)合集5_建模_10

(具体推导过程请见本系列第四篇文章)

其中作差的两项都是 one-hot 形式。

def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:
        flat_output = sandwich(output_params)
        # 输出分布在各类别上分配的概率
        pred_probs = self.distribution_factory.get_dist(flat_output).probs
        
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)

        tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)
        kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)
        t = t.flatten(start_dim=1).float()
        loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl
        
        return loss
  • 离散时间的损失函数

Bayesian Flow Networks(BFN)合集5_人工智能_11

(具体推导过程详见本系列第四篇文章)

def discrete_time_loss(
        self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10
    ) -> Tensor:
        flat_target = data.flatten(start_dim=1)
        if self.bayesian_flow.discretize:
            flat_target = float_to_idx(flat_target, self.K)
        
        # 根据 t = \frac{i-1}{n} 反过来计算 i 
        i = t * n_steps + 1
        # \alpha_i
        alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)

        # (B,D,K)
        flat_output = sandwich(output_params)
        # 模型预测的在各个类别上的概率.
        receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs
        # 这里之所以要在倒数第2个维度上加一维是因为以下 components 在每个类别上的均值向量都是 K 维 one-hot,
        # 从而在每个类别上生成的是 K 个相互独立的正态分布. 总共有 K 类, 于是就有 K x K 个分布.
        # 因此这里增加维度是为了让 categorical 权重 与 components 对齐.
        receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))
        
        # 增加2个维度是为了对应 batch dim: B 和 data dim: D.
        classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)
        receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))
        # 接收者分布, 它是多个混合高斯分布的联合分布, 其中每个数据维度都是混合高斯分布.
        receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)
        
        sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)
        # 从发送者分布中采样, 以蒙特卡洛方法近似估计其与接收者分布之间的 KL loss
        y = sender_dist.sample(torch.Size([n_samples]))
        
        # (B,1)
        loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)
        
        return loss

在实现时,需要特别注意的是在构造混合高斯分布(也就是接收者分布)时,传到 receiver_mix_dist 中 receiver_mix_wts 的 shape 和 传给 receiver_components 的 classes 的 shape,具体细节 CW 都在以上做了注解,这里就不再重复阐述。

  • 重构损失

Bayesian Flow Networks(BFN)合集5_人工智能_12

def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:
        flat_outputs = sandwich(output_params)
        flat_data = data.flatten(start_dim=1)
        output_dist = self.distribution_factory.get_dist(flat_outputs)
        
        return -output_dist.log_prob(flat_data)

回顾论文,你会发现重构损失是在贝叶斯流分布的期望下计算的:

Bayesian Flow Networks(BFN)合集5_建模_13

五、输出分布

在前面的代码中我们经常看到 distribution_factory 这个工厂对象的出现,它的大招就是返回输出分布,不同类型的输出分布会由对应类型的工厂对象返回。经过本系列前面几篇的理论铺垫,我们知道建模不同类型的数据所用的输出分布类型也是不同的,那么这一章就一起来扒扒这些分布的代码实现。

连续型与离散型分布

首先,不论是哪种分布,都可被归类为连续型分布或离散型分布,作者分别用了两个类来表示,它们作为其余分布的基类,里面包含了作为一个分布理应具备的一些属性与方法。

CONST_log_min = 1e-10


def safe_log(data: Tensor):
    return data.clamp(min=CONST_log_min).log()


class CtsDistribution:
    @abstractmethod
    def log_prob(self, x):
        pass

    @abstractmethod
    def sample(self):
        pass


class DiscreteDistribution:
    @property
    @abstractmethod
    def probs(self):
        pass

    @functools.cached_property
    def log_probs(self):
        return safe_log(self.probs)

    @functools.cached_property
    def mean(self):
        pass

    @functools.cached_property
    def mode(self):
        pass

    @abstractmethod
    def log_prob(self, x):
        pass

    @abstractmethod
    def sample(self):
        pass

离散化分布

另外还有一种比较基本的分布就是离散化(discretized)分布了,它代表将一个连续型分布离散化为离散型分布,于是它最终是离散型表示,从而继承了离散型分布 DiscreteDistribution。

class DiscretizedDistribution(DiscreteDistribution):
    def __init__(self, num_bins, device):
        # 离散区间数量: K
        self.num_bins = num_bins
        # 原数据取值范围是[-1,1], 如今划分为 K 个区间, 因此每个区间宽度是 2/K.
        self.bin_width = 2.0 / num_bins
        self.half_bin_width = self.bin_width / 2.0

        self.device = device

    @functools.cached_property
    def class_centres(self):
        # 类别中心的取值范围: [-1 + 1/K, 1 - 1/K]
        return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device)

    @functools.cached_property
    def class_boundaries(self):
        # 各类别之间的边界: [-1 + 2/K, 1 - 2/K], 共 K-1 个.
        return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device)

    @functools.cached_property
    def mean(self):
        # 将各类别中心用它们各自所对应的概率加权求和: \sum_{k=1}^K{p_k * k_c}
        return (self.probs * self.class_centres).sum(-1)

    @functools.cached_property
    def mode(self):
        """概率分布的 mode, 代表众数, 即概率最高处所对应的样本."""

        # 因为 class_centres 是1维的, 所以这里需要将索引展平.
        mode_idx = self.probs.argmax(-1).flatten()
        return self.class_centres[mode_idx].reshape(self.probs.shape[:-1])

上面那个类是离散化分布的父类,而要对一个连续型分布实现离散化,那么你得将它作为参数接收进来然后进行处理,于是就有了下面这个子类,它继承了上面那家伙。

class DiscretizedCtsDistribution(DiscretizedDistribution):
    """将一个连续型分布离散化."""
    
    def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):
        super().__init__(num_bins, device)

        # 原来的连续型分布, 要对其进行离散化处理.
        self.cts_dist = cts_dist
        # log(2/K)
        self.log_bin_width = log(self.bin_width)
        # B
        self.batch_dims = batch_dims
        
        # 是否要对原来连续型分布的 CDF 做截断.
        self.clip = clip
        # 用作概率的极小值
        self.min_prob = min_prob

    @functools.cached_property
    def probs(self):
        """计算数据位于各离散区间的概率."""

        # shape: [K-1] + [1] * B
        bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))
        # shape: [1] + [1] * B
        bdry_slice = bdry_cdfs[:1]
        
        if self.clip:
            '''对原来连续型分布的 CDF 做截断: 小于第一个区间的左端概率置0、小于等于最后一个区间右端的概率置1.'''
            
            cdf_min = torch.zeros_like(bdry_slice)
            cdf_max = torch.ones_like(bdry_slice)
            # shape: [K+1] + [1] * B
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)

            # 利用 CDF(k_r) - CDF(k_l) 得到位于各区间的概率.
            # shape: [1] * B + [K]
            return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)
        else:
            '''以条件概率的思想来计算数据位于各区间的概率,其中的条件就是数据位于 [-1,1] 取值范围内.
            先计算原连续型分布在 1 和 -1 处的 CDF 值,将两者作差从而得到位于 [-1,1] 内的概率,以此作为条件对各区间的概率进行缩放.'''

            # CDF(-1)
            cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)
            # CDF(1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))
            # shape: [K+1] + [1] * B
            bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)

            # p_{-1 < x <= 1}
            cdf_range = cdf_max - cdf_min
            cdf_mask = cdf_range < self.min_prob
            # 当 cdf_range 小于就以 1 代替, 避免作为分母时造成结果溢出.
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)

            # shape: [K] + [1] * B
            probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range
            # 若整个 cdf_range 太小, 说明各区间的概率差异微不足道, 因此干脆将每个区间的概率都用 1/K 即均等的概率代替.
            probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)

            # shape: [1] * B + [K]
            return probs.moveaxis(0, -1)

    def prob(self, x):
        # 区间索引 k \in [0, K-1]
        class_idx = float_to_idx(x, self.num_bins)
        # 区间中心 k_c
        centre = idx_to_float(class_idx, self.num_bins)
        # CDF(k_l), 其中 k_l 代表区间左端点.
        cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)
        # CDF(k_r), 其中 k_r 代表区间右端点.
        cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)
        
        if self.clip:
            '''对原来连续型分布的 CDF 做截断, 使得:
            CDF(k <= 0) = 0;
            CDF(k >= K-1) = 1'''
            
            cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)
            cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)
            
            return cdf_hi - cdf_lo
        else:
            '''以条件概率的思想来计算数据位于某个离散区间内的概率,其中的条件就是数据位于 [-1,1] 取值范围内.
            先计算原连续型分布在 1 和 -1 处的 CDF 值,将两者作差从而得到位于 [-1,1] 内的概率,以此作为条件对区间的概率进行缩放.'''
            
            cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(centre))
            cdf_range = cdf_max - cdf_min
            
            # 若 cdf_range 太小,则设置 mask,并将其以1代替,即不对区间的概率进行缩放, 否则会使得计算出来的采样概率非常接近于1.
            # 两个非常小的值相除, 由于它们都很小、非常接近,因此商接近于1.
            cdf_mask = cdf_range < self.min_prob
            cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)
            prob = (cdf_hi - cdf_lo) / cdf_range
            
            # 若整个 cdf_range 太小, 说明各区间的概率差异微不足道, 因此干脆将区间的概率都用 1/K 即均等的概率代替.
            return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)

    def log_prob(self, x):
        prob = self.prob(x)

        return torch.where(
            prob < self.min_prob,
            # 将 x 以对应区间的中点 k_c 表示并计算出其在原来连续分布中的对数概率密度: log(p(k_c)).
            # 这里加上 log(2/K) 相当于将 k_c 乘以 2/K 再取对数.
            self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,
            safe_log(prob),
        )

    def sample(self, sample_shape=torch.Size([])):
        if self.clip:
            # 直接从原来的连续型分布中采样, 然后将其量化至对应的离散化区间.
            # 此处, clip 的意思是:
            # 若小于第一个区间,则以第一个区间中点表示;
            # 同理,若大于最后一个区间,则以最后一个区间的中点表示.
            return quantize(self.cts_dist.sample(sample_shape), self.num_bins)
        else:
            # 要求原来连续型分布的 CDF 存在反函数, 即可以根据概率值逆向求出对应的样本.
            assert hasattr(self.cts_dist, "icdf")
            
            # 数据的取值范围是 [-1,1], 先根据原来的连续型分布计算出 CDF(-1) 和 CDF(1),
            # 然后利用 CDF 的反函数仅在这个 range 内考虑采样.
            cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)
            cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))

            # 由于 CDF 是服从均匀分布的, 因此从均匀分布中采样出 CDF 值并利用反函数求出对应样本就等价于从目标分布中采样.
            u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)
            cts_samp = self.cts_dist.icdf(u)

            # 最后将样本量化至对应的离散化区间.
            # 注意, 与前面 clip 的方式不同, 此处在量化前样本已经处于有效的离散化区间内了, 因为采样区间是在[-1,1]内考虑的.
            return quantize(cts_samp, self.num_bins)

别被吓倒,虽然代码看起来很复杂,但 CW 已经在上面做了详细的注解,结合对数据进行离散化的知识,理解上面的代码应该是 NO problem 的!

(关于离散化操作的知识背景,可参考本系列第三篇文章)

Discretized Normal Distribution

在建模离散化数据时,输出分布是离散的正态分布:

Bayesian Flow Networks(BFN)合集5_人工智能_14

以下就是这个离散正态分布的实现,它继承了前面展示的 DiscretizedCtsDistributon 类。得益于它的爸爸(帮它将核心功能都实现完了),你可以看到这个类非常躺平(实现得非常简单)~

CONST_exp_range = 10


def safe_exp(data: Tensor):
    return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()


class DiscretizedNormal(DiscretizedCtsDistribution):
    def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        assert params.size(-1) == 2
        
        if min_std_dev < 0:
            min_std_dev = 1.0 / (num_bins * 5)
            
        mean, std_dev = params.split(1, -1)[:2]
        if log_dev:
            # 若传入的是对数标准差, 那么此处就需要取自然指数进行还原.
            std_dev = safe_exp(std_dev)
        std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev)
        
        super().__init__(
            cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False),
            num_bins=num_bins,
            device=params.device,
            # 注意所谓的 batch dims 并非指数据的 batch size,
            # 而是除离散化区间数量以外与分布本身关系不大的其它维度.
            batch_dims=params.ndim - 1,
            clip=clip,
            min_prob=min_prob,
        )

Delta Distribution

建模连续数据时,输出分布为 Delta 分布,它像是个“山寨分布”似的,只有一个单点,实现起来比起前面那位离散的正态分布,有过之而无不及,人家靠的是爸爸才躺平,而它仅靠自己也很躺..

class DeltaDistribution(CtsDistribution):
    def __init__(self, mean, clip_range=1.0):
        if clip_range > 0:
            mean = mean.clip(min=-clip_range, max=clip_range)
        self.mean = mean

    def mode(self):
        return self.mean

    def mean(self):
        return self.mean

    def sample(self, sample_shape=torch.Size([])):
        return self.mean

既然你这么躺那么我 CW 也小躺一下——注解我就不做了(傲娇脸)。

Bernoulli Distribution

对于二值的离散数据,输出分布为伯努利分布,作者在 MNIST 的实验中就是这么玩的,他对 MNIST 数据集进行了二值化处理(关于二值化的实现 CW 在后文会详细解析),使其变身为动态二值化(dynamically binarized)的 MNIST。

from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli


class Bernoulli(DiscreteDistribution):
    def __init__(self, logits):
        self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)

    @functools.cached_property
    def probs(self):
        p = self.bernoulli.probs.unsqueeze(-1)
        return torch.cat([1 - p, p], -1)

    @functools.cached_property
    def mode(self):
        return self.bernoulli.mode

    def log_prob(self, x):
        return self.bernoulli.log_prob(x.float())

    def sample(self, sample_shape=torch.Size([])):
        return self.bernoulli.sample(sample_shape)

从上面的代码可以看出,这个分布实现起来也非常 easy,主要靠的是 Pytorch 内置的实现,只不过 torch 的实现在返回 probs 时仅返回了单类概率,作者则额外补充了剩下一类的概率并将两者拼接起来返回。

Categorical Distribution

当面对多类别的离散数据时,输出分布理应就是类别分布了,作者对于 text8 数据集的建模就采取了这种玩法。

from torch.distributions.categorical import Categorical as torch_Categorical


class Categorical(DiscreteDistribution):
    def __init__(self, logits):
        self.categorical = torch_Categorical(logits=logits, validate_args=False)
        self.n_classes = logits.size(-1)

    @functools.cached_property
    def probs(self):
        return self.categorical.probs

    @functools.cached_property
    def mode(self):
        return self.categorical.mode

    def log_prob(self, x):
        return self.categorical.log_prob(x)

    def sample(self, sample_shape=torch.Size([])):
        return self.categorical.sample(sample_shape)

这里的实现完全依赖了 Pytorch 内置的实现,只不过额外记录了类别数 n_classes 这一属性。

好家伙,在建模离散数据时居然靠着开源的力量躺平~!

Distribution Factory

前面提到,输出分布是由对应的工厂对象制作出来的,以下就是各种工厂类的实现,它们都分别继承连续型分布的工厂类或离散型分布的工厂类,这两个类都是抽象基类,定义了子类必须实现的方法 get_dist() —— 返回一个指定参数的分布。

class CtsDistributionFactory:
    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:
        """Note: input_params and t are not used but kept here to be consistency with DiscreteDistributionFactory."""
        pass


class DiscreteDistributionFactory:
    @abstractmethod
    def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:
        """Note: input_params and t are only required by PredDistToDataDistFactory."""
        pass


class DiscretizedNormalFactory(DiscreteDistributionFactory):
    def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):
        self.num_bins = num_bins
        self.clip = clip
        self.min_std_dev = min_std_dev
        self.max_std_dev = max_std_dev
        self.min_prob = min_prob
        self.log_dev = log_dev

    def get_dist(self, params, input_params=None, t=None):
        return DiscretizedNormal(
            params,
            num_bins=self.num_bins,
            clip=self.clip,
            min_std_dev=self.min_std_dev,
            max_std_dev=self.max_std_dev,
            min_prob=self.min_prob,
            log_dev=self.log_dev,
        )


class DeltaFactory(CtsDistributionFactory):
    def __init__(self, clip_range=1.0):
        self.clip_range = clip_range

    def get_dist(self, params, input_params=None, t=None):
        return DeltaDistribution(params.squeeze(-1), self.clip_range)


class BernoulliFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return Bernoulli(logits=params.squeeze(-1))


class CategoricalFactory(DiscreteDistributionFactory):
    def get_dist(self, params, input_params=None, t=None):
        return Categorical(logits=params)
  • 将噪声分布转换为数据分布

前面在解析连续和离散化数据的损失函数 CtsBayesianFlowLoss 时,有以下这样一段代码:

class CtsBayesianFlowLoss(Loss):
    """建模连续/离散化数据场景时所用的损失函数, 包括:
    -离散时间损失函数;
    -连续时间损失函数;
    -重构损失"""

    def __init__(
        self,
        bayesian_flow: CtsBayesianFlow,
        distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],
        min_loss_variance: float = -1,
        noise_pred: bool = True,
    ):  
        ...  # 此处省略一大段
      
        # 是否预测噪声(亦或是直接预测数据)
        self.noise_pred = noise_pred
        if self.noise_pred:
            self.distribution_factory.log_dev = False
            # 在预测噪声的情况下, 将预测的噪声(或噪声分布相关的参数)转换为对应数据分布(输出分布)的参数.
            self.distribution_factory = PredDistToDataDistFactory(
                self.distribution_factory, self.bayesian_flow.min_variance
            )

也就是说,当模型输出(预测)的是噪声变量或噪声分布的参数时,需要将其转换为对应生成的目标数据或目标数据分布(输出分布)所对应的参数。 而前面已经说过,目标数据分布都是由对应的工厂类制造的,于是这个转换过程就由工厂类去实现,这个工厂类就是 PredDistToDataDistFactory。

class PredDistToDataDistFactory(DiscreteDistributionFactory):
    def __init__(self, data_dist_factory, min_variance, min_t=1e-6):
        self.data_dist_factory = data_dist_factory
        # 之所以设为 False 是因为在以下 noise_pred_params_to_data_pred_params() 方法中会将对数标准差使用自然指数进行转换,
        # 而无需原数据分布的工厂自行转换.
        self.data_dist_factory.log_dev = False
        self.min_variance = min_variance
        self.min_t = min_t

    def get_dist(self, params, input_params, t):
        data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t)
        return self.data_dist_factory.get_dist(data_pred_params)

可以看到,它将目标数据分布的工厂对象 data_dist_factory 作为属性,并依靠后者来返回目标数据分布,只不过预先调用了一个将噪声分布相关的参数转换为数据分布相关参数的方法 noise_pred_params_to_data_pred_params()。既然如此,我们就顺藤摸瓜地深入到这个方法中去寻找真理叭~

哦,在看代码之前,还是一起先来回顾下在建模连续和离散化数据时噪声(分布)转换为目标数据分布(即输出分布)的公式,以便和接下来的代码进行对照。

在建模连续数据时,模型预测(输出)的是单个噪声变量(其实是噪声分布的均值向量),对应转换成单点的输出分布(Delta 分布):

Bayesian Flow Networks(BFN)合集5_数据_15

(对于上面公式的具体推导过程可参考本系列第二、三篇文章)     

def noise_pred_params_to_data_pred_params(
    noise_pred_params: torch.Tensor, input_mean: torch.Tensor,
    t: torch.Tensor, min_variance: float, min_t=1e-6
):
    """Convert output parameters that predict the noise added to data, to parameters that predict the data.
    将模型预测的噪声分布的参数转换为数据分布的参数."""

    # (B,L,D)
    data_shape = list(noise_pred_params.shape)[:-1]
    # (B,L*D,NP), NP: num parameters per data
    noise_pred_params = sandwich(noise_pred_params)
    # (B,L*D)
    input_mean = input_mean.flatten(start_dim=1)
    
    if torch.is_tensor(t):
        t = t.flatten(start_dim=1)
    else:
        t = (input_mean * 0) + t
        
    # (B,L*D,1)
    alpha_mask = (t < min_t).unsqueeze(-1)
    
    # \sigma_1^{2t}
    posterior_var = torch.pow(min_variance, t.clamp(min=min_t))
    # \gamma(t) = 1 - \sigma_1^{2t}
    gamma = 1 - posterior_var

    # \frac{\mu}{\gamma(t)}
    A = (input_mean / gamma).unsqueeze(-1)
    # \sqrt{\frac{1-\gamma(t)}{\gamma(t)}}
    B = (posterior_var / gamma).sqrt().unsqueeze(-1)
    
    data_pred_params = []
    
    # 对应建模连续数据的场景: 模型预测的是噪声向量.
    if noise_pred_params.size(-1) == 1:
        noise_pred_mean = noise_pred_params
    # 对应建模离散化数据的场景: 模型预测的是噪声分布的均值与对数标准差. 
    elif noise_pred_params.size(-1) == 2:
        noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)
    else:
        assert noise_pred_params.size(-1) % 3 == 0
        mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)
        data_pred_params.append(mix_wt_logits)

    # 连续数据: x = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \epsilon
    # 离散化数据: \mu_{x} = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \mu_{\epsilon}
    data_pred_mean = A - (B * noise_pred_mean)
    # 时间变量的值过小则被认为是起始时刻, 等同于先验形式, 即标准高斯分布, 于是将预测的均值置0
    data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)
    data_pred_params.append(data_pred_mean)
    
    if noise_pred_params.size(-1) >= 2:
        # 将对数标准差取自然指数复原: exp(ln(\sigma_{\epsilon})) -> \sigma_{\epsilon}
        noise_pred_dev = safe_exp(noise_pred_log_dev)
        # 将噪声分布的标准差转换为目标数据分布的标准差: \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} exp(ln(\sigma_{\epsilon})) -> \mu_x
        data_pred_dev = B * noise_pred_dev
        # 时间变量的值过小则被认为是起始时刻, 等同于先验形式, 即标准高斯分布, 于是将预测的标准差置1
        data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)
        data_pred_params.append(data_pred_dev)

    # (B,L*D,NP)
    data_pred_params = torch.cat(data_pred_params, -1)
    # (B,L,D,NP)
    data_pred_params = data_pred_params.reshape(data_shape + [-1])
    
    return data_pred_params

Bayesian Flow Networks(BFN)合集5_数据_16

六、神经网络的实现

BFN 更多的是一种方法论,而与具体的模型架构无关,于是对神经网络的架构是没有限制的,作者一共实现了3种模型:UNet-VDM(变分扩散模型)、UNet 以及 GPT(参考了 Karpathy 大神的 nano-GPT)。CW 比较懒,实在不愿意将3种模型全部写了,所以干脆就挑最新的 VDM 在这里吹吹水吧~

Model

@torch.no_grad()
def zero_init(module: nn.Module) -> nn.Module:
    """Sets to zero all the parameters of a module, and returns the module."""
    
    for p in module.parameters():
        nn.init.zeros_(p.data)
        
    return module


class UNetVDM(nn.Module):
    def __init__(
        self,
        data_adapters,
        embedding_dim: int = 128,
        n_blocks: int = 32,
        n_attention_heads: int = 1,
        dropout_prob: float = 0.1,
        norm_groups: int = 32,
        input_channels: int = 3,
        use_fourier_features: bool = True,
        attention_everywhere: bool = False,
        image_size: int = 32,
    ):
        super().__init__()

        # 对输入进行前置处理, 比如加入位置编码.
        self.input_adapter = data_adapters["input_adapter"]
        # 将输出转换为目标 形式, 通常是将维度数 project 到指定数.
        self.output_adapter = data_adapters["output_adapter"]
        
        attention_params = dict(
            n_heads=n_attention_heads,
            n_channels=embedding_dim,
            norm_groups=norm_groups,
        )
        
        resnet_params = dict(
            ch_in=embedding_dim,
            ch_out=embedding_dim,
            condition_dim=4 * embedding_dim,
            dropout_prob=dropout_prob,
            norm_groups=norm_groups,
        )
        
        self.embed_conditioning = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.SiLU(),
            nn.Linear(embedding_dim * 4, embedding_dim * 4),
            nn.SiLU(),
        )
        
        total_input_ch = input_channels
        if use_fourier_features:
            self.fourier_features = FourierFeatures()
            # C = (2F + 1)C, 其中 2F 代表傅里叶特征数(sin & cos 各占 F).
            # 经过傅里叶特征变换所输出的通道数为 2FC, 而这部分特征会和原特征拼接起来,
            # 于是通道数总共就为 (2F+1)C.
            total_input_ch *= 1 + self.fourier_features.num_features
            
        self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1)

        # Down path: n_blocks blocks with a resnet block and maybe attention.
        self.down_blocks = nn.ModuleList(
            # 注意, 实际上并没有下采样, 分辨率保持不变.
            UpDownBlock(
                resnet_block=ResnetBlock(**resnet_params),
                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
            )
            for _ in range(n_blocks)
        )

        self.mid_resnet_block_1 = ResnetBlock(**resnet_params)
        self.mid_attn_block = AttentionBlock(**attention_params)
        self.mid_resnet_block_2 = ResnetBlock(**resnet_params)

        # Up path: n_blocks+1 blocks with a resnet block and maybe attention.
        resnet_params["ch_in"] *= 2  # double input channels due to skip connections
        self.up_blocks = nn.ModuleList(
            # 注意, 实际上并没有上采样, 分辨率保持不变.
            UpDownBlock(
                resnet_block=ResnetBlock(**resnet_params),
                attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,
            )
            for _ in range(n_blocks + 1)
        )

        self.conv_out = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim),
            nn.SiLU(),
            # 将最后的输出卷积层初始化为全0.
            zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)),
        )
        
        self.embedding_dim = embedding_dim
        self.input_channels = input_channels
        self.image_size = image_size
        self.use_fourier_features = use_fourier_features

    def forward(
        self,
        data: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        # (B,H*W,C)
        flat_x = self.input_adapter(data, t)
        # (B,H,W,C)
        x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels)

        # (B,) 因为同一个数据样本在各维度上所对应的时间变量一致, 所以只需要取同的样本的其中1个维度即可.
        t = t.float().flatten(start_dim=1)[:, 0]
        # (B,D) 这里 + 0.001 代表小于 0.001 即看作是起始时刻(因此起始时刻不为0), 与 paper 中的描述一致.
        t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim)
        # We will condition on time embedding.
        # (B,4D)
        cond = self.embed_conditioning(t_embedding)

        # (B,C,H,W)
        x_perm = x.permute(0, 3, 1, 2).contiguous()
        # 若设定了要使用傅里叶特征, 则将傅里叶特征拼接过来.
        # (B,(2F+1)C,H,W), 其中 2FC 是傅里叶特征变换模块输出的通道数.
        h = self.maybe_concat_fourier(x_perm)
        # (B,D,H,W)
        h = self.conv_in(h)
        
        hs = [h]
        for down_block in self.down_blocks:  # n_blocks times
            h = down_block(h, cond)
            hs.append(h)
        
        h = self.mid_resnet_block_1(h, cond)
        h = self.mid_attn_block(h)
        h = self.mid_resnet_block_2(h, cond)
        
        for up_block in self.up_blocks:  # n_blocks+1 times
            h = torch.cat([h, hs.pop()], dim=1)
            h = up_block(h, cond)
        
        # (B,H*W,D)
        # 这个最后的卷积层初始化为全0, 因此在参数更新前这个输出特征不起作用,
        # 于是以下才将网络的输入也一并拼接在一起再输入到最后的 linear projection.
        out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous())
        # (B,H*W,C+D)
        out = torch.cat([sandwich(x), out], -1)
        # (B,H*W,out_channels,out_height)
        out = self.output_adapter(out)
        
        return out

    def maybe_concat_fourier(self, z):
        if self.use_fourier_features:
            return torch.cat([z, self.fourier_features(z)], dim=1)
        
        return z

代码也比较好懂,模型的编解码方式遵循 U-Net 的玩法,但在这里却没有真正地进行上、下采样,分辨率一直是保持不变的。 另外,在编解码的同时加入了时间变量(将其处理为 embeddings),以使得模型对于时间拥有感知能力。至于提取特征的基本组件也是老套路了:resnet blcok & self attention。

与常规 VDM 相比,这里的实现有几点比较特殊:

Bayesian Flow Networks(BFN)合集5_人工智能_17

Input Aadapter

作者实现了两种 input adpater,分别是用于语言的 TextInputAdapter 和用于图像的 FourierImageInputAdapter,两者的实质其实都是制作 position embeddings 和 time embeddings 并且附加(element-wise add or concat)在原输入变量上。但这里的 time embeddings 和后文即将展示的 get_timestep_embedding() 中的概念不同,这里主要是对时间变量进行 scale(从而将其取值范围从 [0,1] 缩放至 [-1,1],与输入数据一致),而并不一定对它再进行额外的 projection。

TextInputAdapter 中的 position embeddings 是我们比较熟悉的方式:可学习的 embeddings 或 正弦位置编码;而 FourierImageInputAdapter 在可学习的 embeddings 之余还可能使用傅里叶位置编码,如其名。

def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:
    """Positional encoding as described in original attention is all you need paper"""

    pos = torch.arange(sequence_length).unsqueeze(1)

    pe = torch.zeros((sequence_length, embedding_size))    
    pe[:, 0::2] = torch.sin(
        pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )
    pe[:, 1::2] = torch.cos(
        pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size)
    )

    return pe


class TextInputAdapter(nn.Module):
    """
    A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings.
    """

    def __init__(
        self,
        vocab_size: int,
        seq_len: int,
        output_size: int = 256,
        learn_pos_embedding: bool = False,
    ):
        super().__init__()
        
        self.learn_pos_embedding = learn_pos_embedding
        if learn_pos_embedding:
            self.pos_embedding = nn.Embedding(seq_len, output_size)
        else:
            self.register_buffer("pos_embedding", pe_encode(seq_len, output_size))
            
        self.inp_embedding = nn.Linear(vocab_size, output_size)
        self.t_embedding = nn.Linear(1, output_size)

    def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:
        # 将概率值从[0,1]范围缩放至[-1,1]
        inp_emb = self.inp_embedding(2 * probs - 1)
        
        if self.learn_pos_embedding:
            pos_emb = self.pos_embedding(
                torch.arange(0, probs.size(1)).to(probs.device)
            )
        else:
            pos_emb = self.pos_embedding
        # (B,L,output_size)
        pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)
        
        # 同样将时间变量的范围从[0,1]缩放至[-1,1]
        t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))
        
        output = inp_emb + pos_emb + t_emb

        return output


def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:
    pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)
    
    pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)
    pe[..., 0::2] = torch.sin(
        pos
        / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    pe[..., 1::2] = torch.cos(
        pos
        / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size)
    )
    
    return pe


class FourierImageInputAdapter(nn.Module):
    """
    A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes.
    """

    def __init__(
        self,
        input_channels: int = 3,
        input_shape: Tuple[int, int] = (224, 224),
        n_freq_bands: int = 64,
        output_height: int = 256,
        value_res: int = -1,
        mask_res: int = -1,
        add_pos_feats: bool = True,
        add_mask: bool = True,
        learn_pos_feats: bool = False,
        pos_embed_size: int = 32,
        init_scale: float = 0.02,
    ):
        super().__init__()
        
        self.input_shape = input_shape
        self.n_freq_bands = n_freq_bands
        self.value_res = value_res
        self.mask_res = mask_res
        self.add_pos_feats = add_pos_feats
        self.add_mask = add_mask
        
        if learn_pos_feats:
            pos_feats = nn.Parameter(
                init_scale
                * torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size)
            )
            self.register_parameter("pos_feats", pos_feats)
        else:
            x = torch.linspace(-1.0, 1.0, steps=input_shape[0])
            y = torch.linspace(-1.0, 1.0, steps=input_shape[1])
            # (input_shape[0],input_shape[1])
            x_pos, y_pos = torch.meshgrid(x, y, indexing="ij")
            # (input_shape[0],input_shape[1],2)
            pos = torch.stack((x_pos, y_pos), dim=-1)
            # (L=H*W,2)
            pos = pos.reshape(-1, 2)
            
            x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands)
            y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands)
            # (2,n_freq_bands)
            bands = torch.stack((x_bands, y_bands), dim=0)
            
            # (L,2,n_freq_bands)
            vals = pos[:, :, None] * bands[None, :, :]
            # (L,2*n_freq_bands)
            vals = math.pi * vals.reshape(vals.shape[0], -1)
            # (L,4*n_freq_bands)
            pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1)
            # (L,4*n_freq_bands+2)
            pos_feats = torch.cat([pos_feats, pos], dim=-1)
            self.register_buffer("pos_feats", pos_feats)
            
        img_feat_height = input_channels
        pos_feat_height = pos_feats.size(-1)
        
        if self.mask_res > 0:
            mask_feat_height = (n_freq_bands * 2) + 1
        else:
            mask_feat_height = 1
            
        all_feat_height = img_feat_height
        if add_mask:
            all_feat_height += mask_feat_height
        if add_pos_feats:
            all_feat_height += pos_feat_height
            
        self.output_projection = None
        if output_height != all_feat_height:
            self.output_projection = nn.Linear(all_feat_height, output_height)

    def forward(self, img: Tensor, t: Tensor) -> Tensor:
        # (B,H*W,C)
        flat_img = sandwich(img)
        # (B,H*W,C)
        flat_t = sandwich(t)
        
        # [0,1] -> [-1,1]
        t_feats = (flat_t.float()[..., :1] * 2) - 1
        if self.mask_res > 0:
            t_feats = torch.cat(
                [
                    t_feats,
                    pe_encode_float(
                        t_feats, self.mask_res, self.n_freq_bands * 2
                    ).flatten(start_dim=2),
                ],
                -1,
            )
        
        # (B, H*W, )
        fourier_feats = self.pos_feats.expand(img.size(0), -1, -1)
        
        all_feat_list = [flat_img]
        if self.add_mask:
            all_feat_list.append(t_feats)
        if self.add_pos_feats:
            all_feat_list.append(fourier_feats)
        all_feats = torch.cat(all_feat_list, dim=-1)
        
        if self.output_projection is None:
            output = all_feats
        else:
            output = self.output_projection(all_feats)
            
        return output

Output Adapter

output adapter 实质上就是 projection layer(Linear 层),用于将特征的最后一维映射至指定数目,以满足指定的输出分布形式。

class OutputAdapter(nn.Module):
    def __init__(self, input_height: int, output_channels: int, output_height: int):
        super().__init__()
        
        self.output_channels = output_channels
        self.output_height = output_height
        self.output_projection = nn.Linear(
            input_height, output_channels * output_height
        )

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        output = self.output_projection(inp)
        return output.reshape(
            output.size(0), -1, self.output_channels, self.output_height
        )

Time Embedding

Bayesian Flow Networks(BFN)合集5_建模_18

最后,再将序列分别送进正余弦函数后拼接起来从而恢复为原来 embedding 的维度。

def get_timestep_embedding(
    timesteps,
    embedding_dim: int,
    dtype=torch.float32,
    max_timescale=10_000,
    min_timescale=1,
):
    """正弦位置编码, 相当于将时间变量的值看作是位置."""
    
    # Adapted from tensor2tensor and VDM codebase.
    assert timesteps.ndim == 1
    assert embedding_dim % 2 == 0
    
    num_timescales = embedding_dim // 2
    # num_timescales 个等比元素, 由 1/min_timescale 到 1/max_timescale(包含).
    # logspace 的底默认为 10, 其输入的前两个参数代表起始和终止的幂
    inv_timescales = torch.logspace(  # or exp(-linspace(log(min), log(max), n))
        -np.log10(min_timescale),
        -np.log10(max_timescale),
        num_timescales,
        device=timesteps.device,
    )
    
    timesteps *= 1000.0  # In DDPM the time step is in [0, 1000], here [0, 1]
    emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :]  # (T, D/2)
    
    # sin(t * \frac{1}{10000^{i/d}}), cos(t * \frac{1}{10000^{i/d}})
    return torch.cat([emb.sin(), emb.cos()], dim=1)  # (T, D)

Others

剩下的 modules 主要包括:负责提取傅里叶特征的 FourierFeatures 、实现自注意力的 Attention、ResNet 的套路 ResnetBlock 以及将 ResnetBlock & Attention 放在一起玩以模仿 UNet 但实际并未进行上下采样的 UpDownBlock。这里就不再逐一详细解析了,直接看代码就能 get 到对应的意思。

  • FourierFeatures
class FourierFeatures(nn.Module):
    def __init__(self, first=5.0, last=6.0, step=1.0):
        super().__init__()
        self.freqs_exponent = torch.arange(first, last + 1e-8, step)

    @property
    def num_features(self):
        return len(self.freqs_exponent) * 2

    def forward(self, x):
        assert len(x.shape) >= 2

        # Compute (2pi * 2^n) for n in freqs.
        freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device)  # (F, )
        freqs = 2.0**freqs_exponent * 2 * pi  # (F, )
        freqs = freqs.view(-1, *([1] * (x.dim() - 1)))  # (F, 1, 1, ...)

        # Compute (2pi * 2^n * x) for n in freqs.
        features = freqs * x.unsqueeze(1)  # (B, F, X1, X2, ...)
        features = features.flatten(1, 2)  # (B, F * C, X1, X2, ...)

        # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).
        return torch.cat([features.sin(), features.cos()], dim=1)
  • Attention
def attention_inner_heads(qkv, num_heads):
    """Computes attention with heads inside of qkv in the channel dimension.

    Args:
        qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:
            H = number of heads,
            C = number of channels per head.
        num_heads: number of heads.

    Returns:
        Attention output of shape (B, H*C, T).
    """

    bs, width, length = qkv.shape
    ch = width // (3 * num_heads)

    # Split into (q, k, v) of shape (B, H*C, T).
    q, k, v = qkv.chunk(3, dim=1)

    # 对 Q, K 各自缩放 1/d^{1/4} 相当于 Q, K 矩阵相乘后的结果缩放了 1/(\sqrt{d})
    # Rescale q and k. This makes them contiguous in memory.
    scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrt
    q = q * scale
    k = k * scale

    # Reshape qkv to (B*H, C, T).
    new_shape = (bs * num_heads, ch, length)
    q = q.view(*new_shape)
    k = k.view(*new_shape)
    v = v.reshape(*new_shape)

    # Compute attention.
    weight = einsum("bct,bcs->bts", q, k)  # (B*H, T, T)
    weight = softmax(weight.float(), dim=-1).to(weight.dtype)  # (B*H, T, T)
    out = einsum("bts,bcs->bct", weight, v)  # (B*H, C, T)
    
    return out.reshape(bs, num_heads * ch, length)  # (B, H*C, T)


class Attention(nn.Module):
    """Based on https://github.com/openai/guided-diffusion."""

    def __init__(self, n_heads):
        super().__init__()
        
        self.n_heads = n_heads

    def forward(self, qkv):
        assert qkv.dim() >= 3, qkv.dim()
        assert qkv.shape[1] % (3 * self.n_heads) == 0
        
        spatial_dims = qkv.shape[2:]
        qkv = qkv.view(*qkv.shape[:2], -1)  # (B, 3*n_heads*C, T)
        out = attention_inner_heads(qkv, self.n_heads)  # (B, n_heads*C, T)
        
        return out.view(*out.shape[:2], *spatial_dims).contiguous()


class AttentionBlock(nn.Module):
    """Self-attention residual block."""

    def __init__(self, n_heads, n_channels, norm_groups):
        super().__init__()
        
        assert n_channels % n_heads == 0
        
        self.layers = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),
            # 之所以将通道数扩展3倍是因为后续要输入到 Attention 模块, 为 Q, K ,V 各分配数量一致的通道数.
            nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1),  # (B, 3 * C, H, W)
            Attention(n_heads),
            # 输出卷积层初始化为全0,因此在参数更新前这部分输出特征相当于不起作用.
            zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),
        )

    def forward(self, x):
        return self.layers(x) + x
  • ResnetBlock
class ResnetBlock(nn.Module):
    def __init__(
        self,
        ch_in,
        ch_out=None,
        condition_dim=None,
        dropout_prob=0.0,
        norm_groups=32,
    ):
        super().__init__()
        
        ch_out = ch_in if ch_out is None else ch_out
        
        self.ch_out = ch_out
        self.condition_dim = condition_dim
        
        self.net1 = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),
            nn.SiLU(),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
        )
        
        if condition_dim is not None:
            self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))
        
        self.net2 = nn.Sequential(
            nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),
            nn.SiLU(),
            nn.Dropout(dropout_prob),
            zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),
        )
        
        if ch_in != ch_out:
            self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)

    def forward(self, x, condition):
        h = self.net1(x)
        
        if condition is not None:
            assert condition.shape == (x.shape[0], self.condition_dim)
            
            # 这个条件映射层(全连接层)初始化为全0, 因此在参数更新前条件变量不起作用.
            condition = self.cond_proj(condition)
            # (B,D,1,1)
            condition = condition[:, :, None, None]
            h = h + condition
        
        h = self.net2(h)
        
        if x.shape[1] != self.ch_out:
            x = self.skip_conv(x)
        assert x.shape == h.shape
        
        return x + h
  • UpDownBlock
class UpDownBlock(nn.Module):
    def __init__(self, resnet_block, attention_block=None):
        super().__init__()
        
        self.resnet_block = resnet_block
        self.attention_block = attention_block

    def forward(self, x, cond):
        x = self.resnet_block(x, cond)
        if self.attention_block is not None:
            x = self.attention_block(x)
            
        return x

七、数据加载与预处理

接下来将对数据集加载及预处理的代码实现进行解析,作者在实验中使用的数据集有3种:CIFAR-10、MNIST 以及 TEXT8,除后者是 NLP 玩家专享之外,前两者都是 CV 玩家们的宝宝。

至于我将重点解析的预处理,则都是针对 CV 数据集的,主要是对于 CIFAR-10(RGB 彩图) 的离散化操作 和 专门为 MNIST(单通道灰度图) 设计的动态二值化操作。

数据集加载

CV 的数据集用的是 torchvision 自带的,由于这里不需要图像标签,因此只返回图像本身,这点可从以下 CIFAR10 和 MNIST 的 __getitem__() 方法里看出。至于 TEXT8 数据集,则是从 URL 下载下来后再做一些字符串处理操作。

import numpy as np

import torch
import torchvision

from torchvision import transforms
from torch.utils.data import Dataset, random_split


class MyLambda(torchvision.transforms.Lambda):
    def __init__(self, lambd, arg1):
        super().__init__(lambd)
        self.arg1 = arg1

    def __call__(self, x):
        return self.lambd(x, self.arg1)


class CIFAR10(torchvision.datasets.CIFAR10):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


class MNIST(torchvision.datasets.MNIST):
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0]


def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:
    """
    Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir
    Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)
    Mandatory for text: seq_len
    """
    
    num_bins = cfg.get("num_bins", 256)
    
    if cfg.dataset == "cifar10":
        train_transform_list = [transforms.ToTensor()]
        if cfg.get("horizontal_flip", False):
            train_transform_list.append(transforms.RandomHorizontalFlip())
        train_transform_list.append(MyLambda(rgb_image_transform, num_bins))
        train_transform = transforms.Compose(train_transform_list)
        
        test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])
        
        train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)
        val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)
        test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)

    elif cfg.dataset == "mnist":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                MyLambda(rgb_image_transform, num_bins),
            ]
        )
        
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "bin_mnist":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])
        
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "bin_mnist_cts":
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])
        train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)
        test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)

    elif cfg.dataset == "text8":
        train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len)
        val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len)
        test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len)
        
    else:
        raise NotImplementedError(cfg.dataset)

    if cfg.dataset != "text8":
        # For vision datasets we split the train set into train and val
        # 因为上面划分的 train_set 和 val_set 实际上都是训练集,只不过应用了不同的 transforms,
        # 所以这里需要按比例从训练集中真正划分出验证集.
        val_frac = cfg.get("val_frac", 0.01)
        train_val_split = [1.0 - val_frac, val_frac]
        
        # 固定随机种子使得两个 random_split 划分的结果一致, 这样 train_set 和 val_set 就不用有交集.
        seed = 2147483647
        train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]
        val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]

    return train_set, val_set, test_set


def prepare_text8(data_dir: pathlib.Path):
    data_dir.mkdir(parents=True, exist_ok=True)
    data_url = "http://mattmahoney.net/dc/text8.zip"
    
    with open(data_dir / "text8.zip", "wb") as f:
        print("Downloading text8")
        f.write(requests.get(data_url).content)
        print("Done")
        
    with zipfile.ZipFile(data_dir / "text8.zip") as f:
        f.extractall(data_dir)
    os.remove(data_dir / "text8.zip")
    
    data = (data_dir / "text8").read_text()

    # get all the unique characters that occur in this text
    chars = sorted(list(set(data)))
    vocab_size = len(chars)
    print("all the unique characters:", "".join(chars))
    print(f"vocab size: {vocab_size:,}")

    # create a mapping from characters to integers
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(s):
        return [stoi[c] for c in s]  # encoder: take a string, output a list of integers

    # encode both to integers
    n = len(data)

    # List[int]
    train_data = data[: int(n * 0.9)]
    val_data = data[int(n * 0.9) : int(n * 0.95)]
    test_data = data[int(n * 0.95) :]
    
    train_ids = encode(train_data)
    val_ids = encode(val_data)
    test_ids = encode(test_data)
    
    print(f"train has {len(train_ids):,} tokens")
    print(f"val has {len(val_ids):,} tokens")
    print(f"test has {len(test_ids):,} tokens")

    # export to bin files
    train_ids = np.array(train_ids, dtype=np.uint16)
    val_ids = np.array(val_ids, dtype=np.uint16)
    test_ids = np.array(test_ids, dtype=np.uint16)
    
    train_ids.tofile(data_dir / "train.bin")
    val_ids.tofile(data_dir / "val.bin")
    test_ids.tofile(data_dir / "test.bin")
    print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}")

    # Save the meta information as well, to help us encode/decode later
    meta = {
        "vocab_size": vocab_size,
        "itos": itos,
        "stoi": stoi,
    }
    with open(os.path.join(data_dir / "meta.pkl"), "wb") as f:
        pickle.dump(meta, f)

    print(f"text8 dataset downloaded and prepared in dir {data_dir}")


class Text8Dataset(Dataset):
    def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):
        """
        seq_len should include context length. Example: seq_len=512 for modeling 256 chars with 256 char of context.
        context is only used for correct preparation of val/test sets.
        """

        self.seq_len = seq_len
        
        self.split = split
        assert self.split in ["train", "val", "test"]
        fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split]

        self.root_dir = pathlib.Path(data_dir)
        data_dir = self.root_dir / "text8"
        if not os.path.exists(data_dir):
            if download:
                prepare_text8(data_dir)
            else:
                raise NotADirectoryError(f"dir {data_dir} does not exist and download is False")

        # memmap() 将磁盘上的大型二进制文件当作内存中的数组进行处理, shape 若未指定, 则返回的数组将是一维的.
        # order 参数指定数组内存布局的顺序, 可以是 C(行优先) 或 F(列优先), 默认是行优先, 这个参数仅在数组大于1维时有效.
        # 还支持 offset 参数, 加载的数组数据从此偏移量开始. 偏移量应该是 dtype 的字节大小的倍数, 默认为 0.
        self.data = np.memmap(data_dir / fname, np.uint16, "r")

    def __getitem__(self, index) -> torch.Tensor:
        seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64))
        return seq

    def __len__(self):
        return self.data.size - self.seq_len

离散化操作

以下是针对连续数据的离散化操作,即将其“分配”至对应的离散区间,然后使用区间中点值来表示,本质上属于一种量化的过程,这也是以下 quantize() 方法的命名原因。

Bayesian Flow Networks(BFN)合集5_数据_19

刚才说到,量化就是将一个连续的浮点值分配至对应的离散区间,然后再用那个区间的中点值来表示。于是,quantize() 方法就是先调用 float_to_idx() 再调用 idx_to_float()。

但是,在调用 quantize() 方法前,由于数据经过了 torchvision.transforms.ToTensor() 的处理,因此数据值位于 [0,1] 区间,于是要先将其 scale 至 [-1,1] 区间内,如下 rgb_image_transform() 的代码所示。

def idx_to_float(idx: np.ndarray, num_bins: int):
    """将离散化区间索引 k 转换为对应的区间中心值 k_c.
    注意, 此处 k 的取值范围与论文中的不同, 论文中 k 的取值范围是 1~K, 而这里:
    k_c = \frac{2k+1}{K} - 1, where k \in [0, K-1]."""
    
    flt_zero_one = (idx + 0.5) / num_bins
    return (2.0 * flt_zero_one) - 1.0


def float_to_idx(flt: np.ndarray, num_bins: int):
    """根据离散化值 k_c 计算出对应的区间索引 k, 是 float_to_idx() 的逆向操作."""
    
    flt_zero_one = (flt / 2.0) + 0.5
    return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()


def quantize(flt, num_bins: int):
    """将浮点值量化以对应的离散化区间中点 k_c 表示, 因此看作是一个量化的过程."""
    return idx_to_float(float_to_idx(flt, num_bins), num_bins)


def rgb_image_transform(x, num_bins=256):
    """将 RGB 图像进行离散化, 其中 x \in [0,1]"""
    return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()

MNIST 的动态二值化

Bayesian Flow Networks(BFN)合集5_建模_20

def bin_mnist_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()


def bin_mnist_cts_transform(x):
    return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5

以上还有个二值化后变为浮点数(-0.5 or 0.5) 的版本,如 bin_mnist_cts_transform() 所示。

八、训练流程

前面讲的是算法实现和数据处理,现在是时候解析将它们串起来的整个训练流程(https://github.com/nnaisense/bayesian-flow-networks/blob/main/train.py%23L178)了。

import copy
import logging
import math

from collections import defaultdict
from pathlib import Path
from typing import Optional, Tuple

import torch
import neptune

from accelerate import Accelerator
from accelerate.logging import get_logger

from omegaconf import OmegaConf

from rich.logging import RichHandler
from rich.progress import Progress

from torch import nn, optim
from torch.utils.data import DataLoader

from model import BFN
from utils_train import (
    seed_everything, log_cfg,
    checkpoint_training_state,
    init_checkpointing,
    log,
    update_ema,
    ddict,
    make_infinite,
    make_progress_bar, make_config, make_dataloaders, make_bfn,
)


torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)

logger = get_logger(__name__)


def ddict():
    """Infinite default dict to fake neptune run on non-main processes"""
    return defaultdict(ddict)


def main(cfg):
    acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)

    cfg.training.seed = seed_everything(cfg.training.seed)
    logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)

    with acc.main_process_first():
        model, dataloaders, optimizer = setup(cfg)
        
    ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None  # EMA on main proc only
    model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])
    
    # 这个 ddict() 对象是一个无限嵌套的 defaultdict,将其视作假的 neptune run 对象,
    # 用于主进程之外的其它进程,类似一种 placeholder 的角色,而主进程会重新对 run 变量进行赋值,使其成为真正的neptune run 对象。
    run = ddict()
    
    if acc.is_main_process:
        ema.to(acc.device)
        try:
            if cfg.meta.neptune:
                import neptune
                
                run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)
                run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)
                log_cfg(cfg, run)
        except ImportError:
            logger.info("Did not find neptune installed. Logging will be disabled.")

    train(cfg.training, acc, model, ema, dataloaders, optimizer, run)


if __name__ == "__main__":
    cfg_file = OmegaConf.from_cli()['config_file']
    main(make_config(cfg_file))

作者使用了 OmegaConf(https://github.com/omry/omegaconf) 这个配置管理系统,它支持 YAML 格式的文件来定义配置项,并且可以将多个来源(系统环境变量、命令行参数、文件等)的配置项进行合并。

对于实验数据的管理,作者则使用了 neptune(https://neptune.ai/),它可以对实验数据进行追踪、过滤、分组、排序、可视化等,还支持将这些实验结果分享给多人以便合作。

训练是支持分布式的,依赖于大名鼎鼎的 accelerate(https://github.com/huggingface/accelerate),这年头相信大多数人对它已经很熟悉了(不熟悉我也懒得说了~)。

现在来理一理以上 main() 函数的整个流程:首先对分布式相关的东西进行初始化(实例化 Accelerator);然后设置随机种子(seed_everything());接着设置 model, dataloader, optimizer 老三样(setup())并且将它们用 Accelerator 对象(acc) wrap 起来,以便支持分布式训练;哦,这里还对模型做了指数平均移动 EMA(Exponential Moving Average),相当于额外对模型做动量更新,EMA 主要用在评估(validate)阶段;下一步就是对 neptune 做初始化然后将实验配置项记录在其中,注意仅在主进程(main process)上进行即可;最后就是调用 train() 函数开启真正的训练过程了。

接下来先看看 setup() 函数是如何实例化 model, dataloader 以及 optimizer 对象的。

def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:
    """Create the model, dataloader and optimizer"""
    
    dataloaders = make_dataloaders(cfg)
    
    model = make_bfn(cfg.model)    
    if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"):
        # 区分了 decay 与不 decay 的参数.
        params = model.net.get_optim_groups(cfg.optimizer.weight_decay)
    else:
        params = model.net.parameters()
    
    # Instantiate the optimizer using the hyper-parameters in the config
    optimizer = optim.AdamW(params=params, **cfg.optimizer)
    
    return model, dataloaders, optimizer

可以看到,dataloader 和 model 都是通过调用其它函数来完成实例化的,optimizer 则直接使用 Pytorch 内置的 AdamW。此外,还支持对模型参数是否要进行 weight decay 做了区分(但在作者的实现中并非所有模型都支持 get_optim_groups() 这个方法,只有其实现的 GPT 才支持该方法)。

下面进一步来看看 make_bfn() 方法,它被定义在另外的文件 utils_train.py(https://github.com/nnaisense/bayesian-flow-networks/blob/main/utils_train.py#L153) 里。

import model
import networks
import probability

from networks import adapters


def make_from_cfg(module, cfg, **parameters):
    return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None


def make_bfn(cfg: DictConfig):
    data_adapters = {
        "input_adapter": make_from_cfg(adapters, cfg.input_adapter),
        "output_adapter": make_from_cfg(adapters, cfg.output_adapter),
    }
    
    net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)
    bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)
    distribution_factory = make_from_cfg(probability, cfg.distribution_factory)
    loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)
    bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)
    
    return bfn

哦!原来是这样的招数——需要实例化哪个类,就从定义它的文件里将其取出然后再传入对应参数,所以一开始需要先导入包含各个类定义的模块(文件)。

你以为我下一步要给你看 make_dataloaders() 长什么样?不好意思,你误会了。CW 打算先将其晾一晾,搞个熟成,待风味足够时再好好拿出来分享~

现在先回到刚刚训练流程的文件里,看看真正的训练过程 train() 函数是怎么玩的。

def train(
        cfg,
        accelerator: Accelerator,
        model: BFN,
        ema_model: Optional[nn.Module],
        dataloaders: dict,
        optimizer: optim.Optimizer,
        run: "neptune.Run",
        # run: neptune.Run
):
    is_main = accelerator.is_main_process
    pbar = make_progress_bar(is_main)
    run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
    train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)
    checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None
    best_val_loss = math.inf

    train_iter = make_infinite(dataloaders["train"])
    model.train()
    
    with pbar:
        for step in range(cfg.start_step, cfg.n_training_steps + 1):
            step_loss = 0.0
            for _ in range(cfg.accumulate):
                with accelerator.accumulate(model):
                    train_batch = next(train_iter)
                    loss = model(train_batch)
                    
                    accelerator.backward(loss)
                    if accelerator.sync_gradients and cfg.grad_clip_norm > 0:
                        accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)
                        
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)

                step_loss += loss.item()

            update_ema(ema_model, model, cfg.ema_decay)

            if is_main and (step % cfg.checkpoint_interval == 0):
                checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id)
                run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last"))

            log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, cond=is_main and step % cfg.log_interval == 0)
            log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, cond=is_main)

            if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders:
                val_loss = validate(
                    cfg=cfg,
                    model=model,
                    ema_model=ema_model,
                    val_dataloader=dataloaders["val"],
                    step=step,
                    run=run,
                    pbar=pbar,
                    best_val_loss=best_val_loss,
                    checkpoint_root_dir=checkpoint_root_dir,
                    accelerator=accelerator,
                )
                best_val_loss = min(val_loss, best_val_loss)

            # advance=1 代表任务完成度+1
            pbar.update(train_id, advance=1, loss=loss.item())

在这里,首先作者使用了 rich 库的 Progress(https://rich.readthedocs.io/en/stable/reference/progress.html) 对象来用作进度条的显示(可能是他嫌弃 tqdm 太 low 叭~),而 make_progress() 方法就是实例化这个对象并将其返回;然后设置了 checkpoint 的目录(init_checkpointing()),以便记录训练期间的模型权重,免得白搞一场;接着,他将 dataloder 变成一个无限迭代的生成器(make_infinite()),待到达指定步数 n_training_steps 后,就停止训练;最后就是常规的训练迭代了,包括:dataloder 吐数据、模型吃数据进行预测并计算 loss、反向传播更新权重(可能有梯度累积和裁剪)、更新 EMA(update_ema())、每隔一定周期记录 checkpoint、记录 loss 与当前进度、周期性地对模型效果进行评估(若设置了 EMA 则是拿它来做 validation)。

OK,我知道你们可能好奇 make_infinite() 和 update_ema() 具体是怎么做的,没问题,我现在就 show 出来,它们被定义在另外的 utils_train.py(https://github.com/nnaisense/bayesian-flow-networks/blob/main/utils_train.py%23L115) 文件里。

@torch.no_grad()
def update_ema(ema_model, model, ema_decay):
    if ema_model is not None and ema_decay > 0:
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            # ema_i = ema_decay * ema_{i-1} + (1-ema_decay) * model_param
            ema_param.sub_((1 - ema_decay) * (ema_param - model_param))


def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:
    while True:
        for data in dataloader:
            yield data

make_infinite() 实际就是在 dataloader 循环取数据的外面包了一层 while True 无限循环,而 update_ema() 实质上就是将 EMA 模型与当前 model 的权重做加权求和。

现在来看看对模型效果进行评估的过程。

@torch.no_grad()
def validate(
        cfg,
        model: BFN,
        ema_model: nn.Module,
        val_dataloader: DataLoader,
        step: int,
        run: "neptune.Run",
        pbar: Optional[Progress],
        best_val_loss: float,
        checkpoint_root_dir: Optional[Path],
        accelerator: Accelerator,
) -> float:
    """Evaluate model on validation data and save checkpoint if loss improves."""
    
    dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision]
    
    model_to_eval = ema_model if ema_model is not None else model
    model_to_eval.eval()
    
    pbar = pbar or Progress()
    max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)
    val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)

    loss, count = 0.0, 0
    for _ in range(cfg.val_repeats):
        for idx, eval_batch in enumerate(val_dataloader):
            enabled = True if dtype in [torch.float16, torch.bfloat16] else False
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                loss += model_to_eval(eval_batch.to(accelerator.device)).item()
                count += 1
                
            pbar.update(val_id, advance=1, loss=loss / count)
            if (idx + 1) >= max_steps:
                break
            
    loss /= count
    pbar.remove_task(val_id)
    log(run["metrics"]["val"]["loss"], loss, step)

    if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):
        logger.info(f"loss improved: new value is {loss}")
        
        step_checkpoint_path = checkpoint_root_dir / "best"
        run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()
        checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)
        run["metrics/best/loss/metric"] = loss
        run["metrics/best/loss/step"] = step

    model.train()
    
    return loss

其实就像是个简化版的训练过程,主要做的事情就是取数据输入到模型中完成预测,然后计算 loss,最后看 loss 比起之前有无好转(会预先记录历史最优 loss),有的话就保存下 checkpoint 和其它重要信息(checkpoint_training_state())。

import json


def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):
    if checkpoint_dir is None:
        return
    
    logger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}")
    accelerator.save_state(checkpoint_dir)
    
    with open(checkpoint_dir / "info.json", "w") as f:
        json.dump({"step": step, "run_id": run_id}, f)
        
    if ema_model is not None:
        ema_checkpoint_path = checkpoint_dir / "ema_model.pt"
        torch.save(ema_model.state_dict(), ema_checkpoint_path)

九、分布式训练的随机种子

OK,现在来填一填前面埋下的坑 —— make_dataloaders()。之所以先将其腌制下放到最后来享用,是因为 CW 想将其与随机种子的设置放在一起来好好吹下水,先看代码:

def seed_everything(seed: Optional[int]) -> int:
    if seed is None:
        seed = random.randrange(np.iinfo(np.int32).max)
        
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    return seed


def worker_init_function(worker_id: int) -> None:
    """https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""
    
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def get_generator(seed: int):
    g = torch.Generator()
    g.manual_seed(seed)
    
    return g


def make_dataloaders(cfg: DictConfig):
    train_set, val_set, _ = make_datasets(cfg.data)
    
    dataloaders = {
        "train": DataLoader(
            dataset=train_set,
            worker_init_fn=worker_init_function,
            generator=get_generator(cfg.training.seed),
            **cfg.train_loader,
        ),
        "val": DataLoader(
            dataset=val_set,
            worker_init_fn=worker_init_function,
            generator=get_generator(cfg.training.seed),
            **cfg.val_loader,
        ),
    }
    
    return dataloaders

看起来一切都挺正常,在 seed_everything() 中进行全局的随机种子设定,包括:Python, Numpy 和 Pytroch。然后,在 dataloader 里通过 worker_init_fn 对加载数据的 workers 也设置了额外的 worker seed(每个 worker 由独立的 worker_id 识别,会开启额外的 worker 进程),这是为了让每个 worker 拥有不同的随机性,当存在类似数据增强这种操作时能够使得增强后的数据拥有多样性(即各 worker 对应 augment 后的数据呈现不一样)。

但是!在分布式(多 GPUs)训练的情况下,以上实现并不能真正达成 "每个 worker 拥有不同随机性" 这种效果,而是会使得不同 gpu 上拥有相同 worker_id(取值范围通常是 0 ~ num_workers - 1) 的 worker 都有完全一致的随机种子,从而丧失了真正意义上的随机性。

造成这个 bug 的原因是 worker_init_function() 里的 torch.initial_seed() 取决于 get_generator() 里 Generator 对象的 seed + worker_id,Generator 对象的 seed 又由固定的配置项 cfg.training.seed 指定,于是所有 gpu 上的这个值都一样,从而造成不同 gpu 上相同 worker_id 的 worker 最终得到相同的 worker seed。

既然分析出原因,那么解决办法也很简单——让每个 gpu 里 Generator 对象的 seed 不一样即可,比如像这样:

import dist


def get_generator(seed: int):
    import torch.distributed as dist
    
    rank = dist.get_rank() if dist.is_initialized() else 0
    seed += rank
    
    g = torch.Generator()
    g.manual_seed(seed)
    
    return g

这么做之后,你甚至可以不用定义 worker_init_fn,这点在后面较新版的 Pytorch 中已经支持。

对于这个问题,CW 也向作者提了 issue(顺便刷刷存在感),但作者的解决方法是将每个 gpu 的全局随机种子都设得不一样,如下所示:

def seed_everything(seed: Optional[int]):
    assert seed is not None
    seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

但这种做法会导致的一个现象就是所有 gpu 在一开始随机初始化模型参数时,会得到不同的随机参数值。 不过如果是使用 DDP(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) 进行分布式训练,那么就不需要担心。因为在 DDP 的机制下,一开始各个 gpu 的模型参数都会由主卡分配以同步。但是,如果用的不是 DDP 或者希望所有 gpu 在全局都拥有一致的随机性时,这个方法就不适用了。

十、吹水

既然走到了本系列的尾端,出于不舍之情当然得好好吹个水!

再谈与扩散模型的比较

BFN 生成样本是一种迭代过程:首先由一个简单的先验分布开始,将这个分布的参数输入到模型中从而输出数据分布;然后从该分布中采样并加噪,将所得的噪声样本作为观测样本来计算后验以更新先验参数;最后将更新的先验参数再次输入到模型中以输出新的数据分布。就这样不断迭代地更新先验分布和模型的输出分布,待一定步骤后,再从输出分布中采样作为最终生成的样本

可以看出,它与扩散模型类似,也是迭代生成的过程并且使用了噪声。所不同的是,扩散模型在训练时是正向的扩散过程、在推理即采样生成时是反向的去噪过程。而我们的主角 BFN 无论是在训练还是采样生成时都是同一种过程,并且没有显式地进行所谓的去噪。在这里,噪声的角色是作为贝叶斯推理(bayesian inference)的 bridge,制作噪声样本是为了计算后验从而更新先验,而这又恰好达到了去噪的效果,因为先验变得更接近真实数据分布了,可谓是“隐式去噪”

另外,大家也知道,扩散模型在理论上有个强硬的限制——正向过程的扩散步数几乎要达到无穷才能变为纯高斯噪声,从而才能与反向的去噪过程的开头(即纯高斯噪声)完美地对应起來。而 BFN 在根本上则没有这种限制,它在采样生成时的初始先验与训练时初始化的输入分布是一致的(完美对应)。

细说 BFN 的噪声设置

那么,更进一步:在 BFN 的玩法中,为何噪声强度要随着时间增加而减小,即:使用了多个噪声强度,如果仅使用一个噪声强度会怎样?毕竟这样可以省去了设置噪声方案这项麻烦事。

在 BFN 的玩法里,并非直接去定义噪声强度而是定义所谓的精度,作为衡量噪声样本(观测样本)与真实数据的接近程度,代表噪声样本所含的有效数据信息量(密度),于是间接控制了噪声强度,所以我们得从源头——精度出发去分析这个事情。

如本系列第二篇文章所述,精度的设置是出于“有效数据信息能够以恒定速率注入到输入分布中”这一宗旨,从而输入分布中所包含的有效信息量越来越多。这样,模型接收的输入(即输入分布的参数)变得越来靠谱,其对应的输出也会更靠谱。而要贯彻这个宗旨,本质上就是要输入分布的期望熵随时间线性递减。当时在文章中就是基于这个出发点在数学上进行分析,最终推导出精度会随时间递增,从而噪声强度就会递减,这也就是为何要用多个噪声强度的原因。

其实,对于这个问题,就算跳脱数学分析,我们也能看出一些苗头。如果噪声强度不变,即精度不变,那么输入分布中所含的有效数据信息就永远是那么丁点儿,从而模型接收的输入(即输入分布的参数)“含金量”就不高,进而就会导致其输出的质量也就不高了。将模型内部的过程看作是去噪(方才说了,相当于隐式去噪),那么即使它完美地去噪了,还原出来的有效数据也就那么点,避免不了成为“劣质品”。

特点与挑战

CW 认为 BFN 最大也是最亮眼(最容易被看到)的特点就是模型的输入不是数据样本本身,而是数据分布的参数!正是这个优秀的基因,导致其天然地能够在连续型(continuous)输入的基础上愉快地玩转离散型(discrete)数据,而无需施加额外的约束。并且,能够在统一的方法框架下适配图像和语言数据的生成,无需专门针对二者做架构上的修改。另外,由于模型输出的是数据分布,因此能够直接计算似然。

但是,BFN 也面临着诸多挑战与不确定性。在最大层面上来讲,其收敛性、稳定性 以及 泛化性还有待检验。它的计算资源也比起一般的模型更多,因为单次前向过程的同时还需要额外进行贝叶斯推理。另外,输入分布很关键,模型对其依赖性不小。若在面临复杂多样的数据时将先验设置地过于简单,可能会导致最终效果不好。

再细节一些,BFN 的精度(噪声)设置也是件棘手的事情。作者在建模离散数据的实验中,就是发现 accuracy schedule 次优导致效果不佳。最后,CW 一直没挖出来 BFN 在生成多样性方面有什么可取之处(比起其他生成模型),或者,我去骚扰下作者看看叭~

完结撒花

水已吹干,真的该结束了。BFN 的建模方法对于许多朋友来说可能比较难懂,由其是当中涉及的数学推导比较多,我看到外网很多人也表示看不懂 paper,绝望地呼出 "not interesting" 的惨叫.. 也正是因为这样,CW 才决定肝出这个系列,毕竟 BFN 确实属于不无聊的风格。特别是当今满街都是扩散模型像行尸走肉般大肆虐杀,能够有只不一样的东西蹦出来难道不觉得很有意思吗!?