深度学习的一个关键组成部分是多次迭代数据集并执行参数更新。这个过程有时被称为“训练循环”,这个循环通常有很多阶段。SpeechBrain 提供了一个方便的框架来组织训练循环,以称为“大脑”类的类的形式,在speechbrain/core.py
. 在每个配方中,我们对此类进行子类化并覆盖默认实现不执行该特定配方所需的方法。
此类的主要方法是fit()
方法,它接受一组数据并对其进行多次迭代并对模型执行更新。为了使用fit()
,必须在子类中至少定义两个方法:compute_forward()
和compute_objectives()
。这些方法定义了模型的计算以生成预测,以及找到梯度所需的损失项。
这是一个非常简单的例子:
!pip install torchaudio
!pip install speechbrain
import torch
import speechbrain as sb
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
return torch.nn.functional.l1_loss(predictions, batch["target"])
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
brain.fit(range(10), data)
只需大约 10 行代码,我们就可以训练一个神经模型。这是可能的,因为训练中烦人的细节,例如设置train()
和eval()
适当或计算和应用梯度,都由 Brain 类处理。更好的是,该过程的每一步都可以通过向子类添加方法来覆盖,因此即使是复杂的训练过程(例如 GAN)也可以在 Brain 类中完成。
在本教程中,我们首先向 Brain 类解释参数,然后fit()
逐步介绍该方法并显示必要时可以覆盖的部分。这两项是理解这个类如何工作的关键!
Brain
类的论据
Brain 类只接受 5 个参数,但每个参数都可能有点复杂,因此我们在这里详细解释它们。相关代码只是__init__
定义:
def __init__(
self,
modules=None,
opt_class=None,
hparams=None,
run_opts=None,
checkpointer=None,
):
modules
第一个参数采用torch模块字典。Brain 类接受这个字典并将其转换为 Torch ModuleDict。这提供了一种方便的方法来将所有参数移动到正确的设备、调用train()
和eval()
,并在必要时将模块包装在适当的分布式包装器中。
opt_class
Brain 类采用 pytorch 优化器的函数定义。选择它作为输入而不是预先构建的 pytorch 优化器的原因是,如果需要,Brain 类会自动处理将模块参数包装在分布式包装器中。这需要在参数传递给优化器构造函数之前发生。
要传递 pytorch 优化器构造函数,可以使用 lambda,如本教程开头的示例所示。然而,更方便的是 SpeechBrain 中大多数食谱使用的选项:使用 HyperPyYAML 定义构造函数。该!name:
标签作用类似于lambda,创建可以用来做优化一个新的构造。
optimizer: !name:torch.optim.Adam
lr: 0.1
当然,有时需要零个或多个优化器。在多个优化器的情况下,init_optimizers
可以重写该方法以单独初始化每个优化器。
hparams
Brain 类算法可能依赖于一组应该易于从外部控制的超参数,该参数接受一个字典,所有使用“点符号”的内部方法都可以访问该字典。一个例子如下:
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
term1 = torch.nn.functional.l1_loss(predictions, batch["target1"])
term2 = torch.nn.functional.mse_loss(predictions, batch["target2"])
return self.hparams.weight1 * term1 + self.hparams.weight2 * term2
hparams = {"weight1": 0.7, "weight2": 0.3}
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
modules={"model": model},
opt_class=lambda x: torch.optim.SGD(x, 0.1),
hparams=hparams,
)
data = [{
"input": torch.rand(10, 10),
"target1": torch.rand(10, 10),
"target2": torch.rand(10, 10),
}]
brain.fit(range(10), data)
run_opts
有大量用于控制fit()
方法执行细节的选项,它们都可以通过此参数传递。一些示例包括启用调试模式、执行设备和分布式执行选项。如需完整列表,请参阅 [添加文档链接]。
checkpointer
争论
最后,如果您将 SpeechBrain 检查指针传递给 Brain 类,则会自动调用几个操作:
- 优化器参数被添加到检查点。
- 在训练开始时,加载最近的检查点并从该点恢复训练。如果训练完成,这只会结束训练步骤并继续进行评估。
- 在训练期间,默认情况下每 15 分钟保存一次检查点(可以使用 中的选项更改或禁用
run_opts
)。 - 在评估开始时,加载“最佳”检查点,由检查点中记录的指标的最低或最高分数确定。
fit()
方法
这个方法做的很多,但实际上只需要大约100行代码,所以通过阅读代码本身就可以理解。我们将其逐节分解并解释每个部分的作用。首先,让我们简要回顾一下这些论点:
def fit(
self,
epoch_counter,
train_set,
valid_set=None,
progressbar=None,
train_loader_kwargs={},
valid_loader_kwargs={},
):
-
epoch_counter
参数采用迭代器,因此,当fit()
被调用时,外循环迭代这个变量。这个论点是与一个EpochCounter
能够存储纪元循环状态的类共同设计的。有了这个论点,我们可以从他们停止的地方重新开始实验。 - 在
train_set
和valid_set
论据采取火炬数据集或的DataLoader将加载所需的训练张量。如果未传递 DataLoader,则将自动构造一个(请参阅下一节)。 - 该
progressbar
参数控制是否tqdm
显示进度条以显示每个时期数据集的进度。 - 该
train_loader_kwargs
和valid_loader_kwargs
被传递到make_dataloader
用于使的DataLoader(见下一节)方法。
fit结构
抛开参数,我们可以开始看看这个方法的结构。这是一个简单的图形,用于显示fit()
. 我们将在本教程的其余部分逐一介绍这些内容。
make_dataloader
该fit()
方法的第一步是确保数据采用合适的迭代格式。无论是train_set
和valid_set
与它们各自的关键字参数传递下去。这是实际的代码:
if not isinstance(train_set, DataLoader):
train_set = self.make_dataloader(
train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
)
if valid_set is not None and not isinstance(valid_set, DataLoader):
valid_set = self.make_dataloader(
valid_set,
stage=sb.Stage.VALID,
ckpt_prefix=None,
**valid_loader_kwargs,
)
默认情况下,此方法处理 DataLoader 创建的潜在复杂性,例如为分布式执行创建 DistributedSampler。与fit()
调用中的所有其他方法一样,这可以通过make_dataloader
在 Brain 的子类定义中创建一个方法来覆盖。
on_fit_start
除了数据加载器之外,在训练开始之前还需要进行一些设置。这是相关的代码:
self.on_fit_start()
if progressbar is None:
progressbar = self.progressbar
该on_fit_start
方法处理了一些重要的事情,最容易通过共享代码来解释:
def on_fit_start(self):
self._compile_jit()
self._wrap_distributed()
self.init_optimizers()
if self.checkpointer is not None:
self.checkpointer.recover_if_possible(
device=torch.device(self.device)
)
基本上,此方法可确保正确准备torch模块,包括 jit 编译、分布式包装以及使用所有相关参数初始化优化器。优化器初始化还将优化器参数添加到检查指针(如果有)。最后,此方法加载最新的检查点,以便在中断时恢复训练。
on_stage_start
下一部分开始 epoch 迭代并准备迭代训练数据。要调整准备工作,可以覆盖该on_stage_start
方法,这将允许创建容器来存储训练统计信息。
for epoch in epoch_counter:
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
self.nonfinite_count = 0
if self.train_sampler is not None and hasattr(
self.train_sampler, "set_epoch"
):
self.train_sampler.set_epoch(epoch)
last_ckpt_time = time.time()
训练循环
本教程中最长的代码块专门用于训练和验证数据循环。然而,他们实际上只做三件重要的事情:
- 调用
fit_batch()
DataLoader 中的每个批次。 - 跟踪平均损失并报告。
- (可选)定期保存检查点,以便可以恢复培训。
这是代码:
enable = progressbar and sb.utils.distributed.if_main_process()
with tqdm(
train_set, initial=self.step, dynamic_ncols=True, disable=not enable,
) as t:
for batch in t:
self.step += 1
loss = self.fit_batch(batch)
self.avg_train_loss = self.update_average(
loss, self.avg_train_loss
)
t.set_postfix(train_loss=self.avg_train_loss)
if self.debug and self.step == self.debug_batches:
break
if (
self.checkpointer is not None
and self.ckpt_interval_minutes > 0
and time.time() - last_ckpt_time
>= self.ckpt_interval_minutes * 60.0
):
run_on_main(self._save_intra_epoch_ckpt)
last_ckpt_time = time.time()
也许最重要的一步是fit_batch(batch)
调用,我们在这里展示了一个修剪过的版本:
def fit_batch(self, batch):
outputs = self.compute_forward(batch, Stage.TRAIN)
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
loss.backward()
if self.check_gradients(loss):
self.optimizer.step()
self.optimizer.zero_grad()
return loss.detach().cpu()
此方法调用最重要的拟合方法,compute_forward
并且compute_objectives
必须重写这两个方法才能使用 Brain 类。然后反向传播损失,并在应用更新之前检查梯度的非有限值和过大的范数(默认情况下会自动剪裁大范数)。
on_stage_end
在训练循环结束时,on_stage_end
调用该方法进行潜在的清理操作,例如报告训练统计信息。
self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0
验证循环
与训练循环非常相似,验证循环迭代数据加载器并一次处理一批数据。但是,不是调用fit_batch
此循环调用evaluate_batch
,它不会反向传播梯度或应用任何更新。
if valid_set is not None:
self.on_stage_start(Stage.VALID, epoch)
self.modules.eval()
avg_valid_loss = 0.0
with torch.no_grad():
for batch in tqdm(
valid_set, dynamic_ncols=True, disable=not enable
):
self.step += 1
loss = self.evaluate_batch(batch, stage=Stage.VALID)
avg_valid_loss = self.update_average(
loss, avg_valid_loss
)
if self.debug and self.step == self.debug_batches:
break
on_stage_end
该方法与 train 阶段的方法相同,但这次仅在单个进程上执行,因为该进程通常会涉及写入文件。常见用途包括:更新学习率、保存检查点和记录一个时期的统计数据。
self.step = 0
run_on_main(
self.on_stage_end,
args=[Stage.VALID, avg_valid_loss, epoch],
)
最后一件事是简单检查调试模式,只运行几个周期。
if self.debug and epoch == self.debug_epochs:
break
恭喜,您现在知道该fit()
方法的工作原理,以及为什么它是运行实验的有用工具。训练模型的所有部分都被分解了,烦人的部分都得到了处理,同时通过覆盖 Brain 类的任何部分仍然可以获得完全的灵活性。
evaluate()
方法
此方法以与方法的验证数据大致相同的方式迭代测试数据fit()
,包括对on_stage_start
和 的调用on_stage_end
。调用的另一种方法是on_evaluate_start()
方法,默认情况下它会加载最佳检查点以进行评估。
结论
Brain 类和fit()
方法尤其受到其他流行的用于统计和机器学习的 Python 库的启发,特别是 numpy、scipy、keras 和 PyTorch Lightning。
随着我们添加有关 Brain 类更高级用法的教程,我们将在此处添加指向它们的链接。计划教程的一些示例:
- 使用 Brain 类编写 GAN
- 使用 Brain 类进行分布式训练
- Brain 类的非基于梯度的使用