深度学习的一个关键组成部分是多次迭代数据集并执行参数更新。这个过程有时被称为“训练循环”,这个循环通常有很多阶段。SpeechBrain 提供了一个方便的框架来组织训练循环,以称为“大脑”类的类的形式,在speechbrain/core.py. 在每个配方中,我们对此类进行子类化并覆盖默认实现不执行该特定配方所需的方法。

此类的主要方法是fit()方法,它接受一组数据并对其进行多次迭代并对模型执行更新。为了使用fit(),必须在子类中至少定义两个方法:compute_forward()compute_objectives()。这些方法定义了模型的计算以生成预测,以及找到梯度所需的损失项。

这是一个非常简单的例子:

!pip install torchaudio
!pip install speechbrain
import torch
import speechbrain as sb

class SimpleBrain(sb.Brain):
  def compute_forward(self, batch, stage):
    return self.modules.model(batch["input"])

    
  def compute_objectives(self, predictions, batch, stage):
    return torch.nn.functional.l1_loss(predictions, batch["target"])

model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
brain.fit(range(10), data)

只需大约 10 行代码,我们就可以训练一个神经模型。这是可能的,因为训练中烦人的细节,例如设置train()eval()适当或计算和应用梯度,都由 Brain 类处理。更好的是,该过程的每一步都可以通过向子类添加方法来覆盖,因此即使是复杂的训练过程(例如 GAN)也可以在 Brain 类中完成。

在本教程中,我们首先向 Brain 类解释参数,然后fit()逐步介绍该方法并显示必要时可以覆盖的部分。这两项是理解这个类如何工作的关键!

Brain类的论据

Brain 类只接受 5 个参数,但每个参数都可能有点复杂,因此我们在这里详细解释它们。相关代码只是__init__定义:

def __init__(
    self,
    modules=None,
    opt_class=None,
    hparams=None,
    run_opts=None,
    checkpointer=None,
):

modules 

第一个参数采用torch模块字典。Brain 类接受这个字典并将其转换为 Torch ModuleDict。这提供了一种方便的方法来将所有参数移动到正确的设备、调用train()eval(),并在必要时将模块包装在适当的分布式包装器中。

opt_class 

Brain 类采用 pytorch 优化器的函数定义。选择它作为输入而不是预先构建的 pytorch 优化器的原因是,如果需要,Brain 类会自动处理将模块参数包装在分布式包装器中。这需要在参数传递给优化器构造函数之前发生。

要传递 pytorch 优化器构造函数,可以使用 lambda,如本教程开头的示例所示。然而,更方便的是 SpeechBrain 中大多数食谱使用的选项:使用 HyperPyYAML 定义构造函数。该!name:标签作用类似于lambda,创建可以用来做优化一个新的构造。

optimizer: !name:torch.optim.Adam
    lr: 0.1

当然,有时需要零个或多个优化器。在多个优化器的情况下,init_optimizers可以重写该方法以单独初始化每个优化器。

hparams 

Brain 类算法可能依赖于一组应该易于从外部控制的超参数,该参数接受一个字典,所有使用“点符号”的内部方法都可以访问该字典。一个例子如下:

class SimpleBrain(sb.Brain):
  def compute_forward(self, batch, stage):
    return self.modules.model(batch["input"])
    
  def compute_objectives(self, predictions, batch, stage):
    term1 = torch.nn.functional.l1_loss(predictions, batch["target1"])
    term2 = torch.nn.functional.mse_loss(predictions, batch["target2"])
    return self.hparams.weight1 * term1 + self.hparams.weight2 * term2

hparams = {"weight1": 0.7, "weight2": 0.3}
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
  modules={"model": model},
  opt_class=lambda x: torch.optim.SGD(x, 0.1),
  hparams=hparams,
)
data = [{
  "input": torch.rand(10, 10),
  "target1": torch.rand(10, 10),
  "target2": torch.rand(10, 10),
}]
brain.fit(range(10), data)

run_opts 

有大量用于控制fit()方法执行细节的选项,它们都可以通过此参数传递。一些示例包括启用调试模式、执行设备和分布式执行选项。如需完整列表,请参阅 [添加文档链接]。

checkpointer 争论

最后,如果您将 SpeechBrain 检查指针传递给 Brain 类,则会自动调用几个操作:

  1. 优化器参数被添加到检查点。
  2. 在训练开始时,加载最近的检查点并从该点恢复训练。如果训练完成,这只会结束训练步骤并继续进行评估。
  3. 在训练期间,默认情况下每 15 分钟保存一次检查点(可以使用 中的选项更改或禁用run_opts)。
  4. 在评估开始时,加载“最佳”检查点,由检查点中记录的指标的最低或最高分数确定。

fit()方法

这个方法做的很多,但实际上只需要大约100行代码,所以通过阅读代码本身就可以理解。我们将其逐节分解并解释每个部分的作用。首先,让我们简要回顾一下这些论点:

def fit(
    self,
    epoch_counter,
    train_set,
    valid_set=None,
    progressbar=None,
    train_loader_kwargs={},
    valid_loader_kwargs={},
):

 

  1. epoch_counter参数采用迭代器,因此,当fit()被调用时,外循环迭代这个变量。这个论点是与一个EpochCounter能够存储纪元循环状态的类共同设计的。有了这个论点,我们可以从他们停止的地方重新开始实验。
  2. train_setvalid_set论据采取火炬数据集或的DataLoader将加载所需的训练张量。如果未传递 DataLoader,则将自动构造一个(请参阅下一节)。
  3. progressbar参数控制是否tqdm显示进度条以显示每个时期数据集的进度。
  4. train_loader_kwargsvalid_loader_kwargs被传递到make_dataloader用于使的DataLoader(见下一节)方法。

fit结构

抛开参数,我们可以开始看看这个方法的结构。这是一个简单的图形,用于显示fit(). 我们将在本教程的其余部分逐一介绍这些内容。

make_dataloader

fit()方法的第一步是确保数据采用合适的迭代格式。无论是train_setvalid_set与它们各自的关键字参数传递下去。这是实际的代码:

if not isinstance(train_set, DataLoader):
    train_set = self.make_dataloader(
        train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
    )
if valid_set is not None and not isinstance(valid_set, DataLoader):
    valid_set = self.make_dataloader(
        valid_set,
        stage=sb.Stage.VALID,
        ckpt_prefix=None,
        **valid_loader_kwargs,
    )

默认情况下,此方法处理 DataLoader 创建的潜在复杂性,例如为分布式执行创建 DistributedSampler。与fit()调用中的所有其他方法一样,这可以通过make_dataloader在 Brain 的子类定义中创建一个方法来覆盖。

on_fit_start

除了数据加载器之外,在训练开始之前还需要进行一些设置。这是相关的代码:

self.on_fit_start()

if progressbar is None:
    progressbar = self.progressbar

on_fit_start方法处理了一些重要的事情,最容易通过共享代码来解释:

def on_fit_start(self):
    self._compile_jit()
    self._wrap_distributed()
    self.init_optimizers()
    if self.checkpointer is not None:
        self.checkpointer.recover_if_possible(
            device=torch.device(self.device)
        )

基本上,此方法可确保正确准备torch模块,包括 jit 编译、分布式包装以及使用所有相关参数初始化优化器。优化器初始化还将优化器参数添加到检查指针(如果有)。最后,此方法加载最新的检查点,以便在中断时恢复训练。

on_stage_start

下一部分开始 epoch 迭代并准备迭代训练数据。要调整准备工作,可以覆盖该on_stage_start方法,这将允许创建容器来存储训练统计信息。

for epoch in epoch_counter:
    self.on_stage_start(Stage.TRAIN, epoch)
    self.modules.train()
    self.nonfinite_count = 0
    if self.train_sampler is not None and hasattr(
        self.train_sampler, "set_epoch"
    ):
        self.train_sampler.set_epoch(epoch)
    last_ckpt_time = time.time()

训练循环

本教程中最长的代码块专门用于训练和验证数据循环。然而,他们实际上只做三件重要的事情:

  1. 调用fit_batch()DataLoader 中的每个批次。
  2. 跟踪平均损失并报告。
  3. (可选)定期保存检查点,以便可以恢复培训。

这是代码:

enable = progressbar and sb.utils.distributed.if_main_process()
with tqdm(
    train_set, initial=self.step, dynamic_ncols=True, disable=not enable,
) as t:
    for batch in t:
        self.step += 1
        loss = self.fit_batch(batch)
        self.avg_train_loss = self.update_average(
            loss, self.avg_train_loss
        )
        t.set_postfix(train_loss=self.avg_train_loss)

        if self.debug and self.step == self.debug_batches:
            break

        if (
            self.checkpointer is not None
            and self.ckpt_interval_minutes > 0
            and time.time() - last_ckpt_time
            >= self.ckpt_interval_minutes * 60.0
        ):
            run_on_main(self._save_intra_epoch_ckpt)
            last_ckpt_time = time.time()

也许最重要的一步是fit_batch(batch)调用,我们在这里展示了一个修剪过的版本:

def fit_batch(self, batch):
    outputs = self.compute_forward(batch, Stage.TRAIN)
    loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
    loss.backward()
    if self.check_gradients(loss):
        self.optimizer.step()
    self.optimizer.zero_grad()
    return loss.detach().cpu()

 此方法调用最重要的拟合方法,compute_forward并且compute_objectives必须重写这两个方法才能使用 Brain 类。然后反向传播损失,并在应用更新之前检查梯度的非有限值和过大的范数(默认情况下会自动剪裁大范数)。

on_stage_end

在训练循环结束时,on_stage_end调用该方法进行潜在的清理操作,例如报告训练统计信息。

self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0

验证循环

与训练循环非常相似,验证循环迭代数据加载器并一次处理一批数据。但是,不是调用fit_batch此循环调用evaluate_batch,它不会反向传播梯度或应用任何更新。

if valid_set is not None:
    self.on_stage_start(Stage.VALID, epoch)
    self.modules.eval()
    avg_valid_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(
            valid_set, dynamic_ncols=True, disable=not enable
        ):
            self.step += 1
            loss = self.evaluate_batch(batch, stage=Stage.VALID)
            avg_valid_loss = self.update_average(
                loss, avg_valid_loss
            )

            if self.debug and self.step == self.debug_batches:
                break

on_stage_end

该方法与 train 阶段的方法相同,但这次仅在单个进程上执行,因为该进程通常会涉及写入文件。常见用途包括:更新学习率、保存检查点和记录一个时期的统计数据。

self.step = 0
run_on_main(
    self.on_stage_end,
    args=[Stage.VALID, avg_valid_loss, epoch],
)

最后一件事是简单检查调试模式,只运行几个周期。

if self.debug and epoch == self.debug_epochs:
    break

恭喜,您现在知道该fit()方法的工作原理,以及为什么它是运行实验的有用工具。训练模型的所有部分都被分解了,烦人的部分都得到了处理,同时通过覆盖 Brain 类的任何部分仍然可以获得完全的灵活性。

evaluate()方法

此方法以与方法的验证数据大致相同的方式迭代测试数据fit(),包括对on_stage_start和 的调用on_stage_end。调用的另一种方法是on_evaluate_start()方法,默认情况下它会加载最佳检查点以进行评估。

结论

Brain 类和fit()方法尤其受到其他流行的用于统计和机器学习的 Python 库的启发,特别是 numpy、scipy、keras 和 PyTorch Lightning。

随着我们添加有关 Brain 类更高级用法的教程,我们将在此处添加指向它们的链接。计划教程的一些示例:

  • 使用 Brain 类编写 GAN
  • 使用 Brain 类进行分布式训练
  • Brain 类的非基于梯度的使用