PyTorch已经足够简单易用,但是简单易用不等于方便快捷。特别是做大量实验的时候,很多东西都会变得复杂,代码也会变得庞大,这时候就容易出错。
针对这个问题,就有了PyTorch Lightning。它可以重构你的PyTorch代码,抽出复杂重复部分,让你专注于核心的构建,让你的实验更快速更便捷地开展迭代。
1. Lightning 简约哲学
大部分的DL/ML代码都可以分为以下这三部分:
- 研究代码 Research code
- 工程代码 Engineering code
- 非必要代码 Non-essential code
1.1 研究代码 Research code
这部分属于模型(神经网络)部分,一般处理模型的结构、训练等定制化部分。
在Linghtning中,这部分代码抽象为 LightningModule 类。
1.2 工程代码 Engineering code
这部分代码很重要的特点是:重复性强,比如说设置early stopping、16位精度、GPUs分布训练。
在Linghtning中,这部分抽象为 Trainer 类。
1.3 非必要代码 Non-essential code
这部分代码有利于实验的进行,但是和实验没有直接关系,甚至可以不使用。比如说检查梯度、给tensorboard输出log。
在Linghtning中,这部分抽象为 Callbacks 类。
2. 典型的AI研究项目
在大多数研究项目中,研究代码 通常可以归纳到以下关键部分:
- 模型
- 训练/验证/测试 数据
- 优化器
- 训练/验证/测试 计算
上面已经提到,研究代码 在 Lightning 中,是抽象为 LightningModule 类;而这个类与我们平时使用的 torch.nn.Module 是一样的(在原有代码中直接替换 Module 而不改其他代码也是可以的),但不同的是,Lightning 围绕 torch.nn.Module 做了很多功能性的补充,把上面4个关键部分都囊括了进来。
这么做的意义在于:我们的 研究代码 都是围绕 我们的神经网络模型 来运行的,所以 Lightning 把这部分代码都集合在一个类里。
所以我们接下来的介绍,都是围绕 LightningModule 类来展开。
3. 生命周期
为了让大家先有一个总体的概念,在这里我先让大家清楚 LightningModule 中运行的生命流程。
以下所有的函数,都是在 LightningModule 类 里。
这部分是训练开始之后的执行 “一般(默认)顺序”。
- 首先是准备工作,包括初始化 LightningModule,准备数据 和 配置优化器。
这部分代码 只执行一次。
1. `__init__()`(初始化 LightningModule )
2. `prepare_data()` (准备数据,包括下载数据、预处理等等)
3. `configure_optimizers()` (配置优化器)
- 测试 “验证代码”。
提前来做的意义在于:不需要等待漫长的训练过程才发现验证代码有错。
这部分就是提前执行 “验证代码”,所以和下面的验证部分是一样的。
1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()`
- 开始加载dataloader,用来给训练加载数据
1. `train_dataloader()`
2. `val_dataloader()` (如果你定义了)
- 下面部分就是循环训练了,
_step()
的意思就是按batch来进行的部分;_epoch_end()
就是所有batch执行完后要进行的部分。
# 循环训练与验证
1. `training_step()`
2. `validation_step()`
3. `validation_epoch_end()`
- 最后训练完了,就要进行测试,但测试部分需要手动调用
.test()
,这是为了避免误操作。
# 测试(需要手动调用)
1. `test_dataloader()`
2. `test_step()`
3. `test_epoch_end()`
在这里,我们很容易总结出,在训练部分,主要是三部分:_dataloader
/_step
/_epoch_end
。Lightning把训练的三部分抽象成三个函数,而我们只需要“填鸭式”地补充这三部分,就可以完成模型训练部分代码的编写。
为了让大家更清晰地了解这三部分的具体位置,下面用 PyTorch实现方式 来展现其位置。
for epoch in epochs:
for batch in train_dataloader:
# train_step
# ....
# train_step
loss.backward()
optimizer.step()
optimizer.zero_grad()
for batch in val_dataloader:
# validation_step
# ....
# validation_step
# *_step_end
# ....
# *_step_end
4. 使用Lightning的好处
- 只需要专注于 研究代码
不需要写一大堆的 .cuda()
和 .to(device)
,Lightning会帮你自动处理。如果要新建一个tensor,可以使用type_as
来使得新tensor处于相同的处理器上。
def training_step(self, batch, batch_idx):
x, y = batch
# 把z放在和x一样的处理器上
z = sample_noise()
z = z.type_as(x)
在这里,有个地方需要注意的是,不是所有的在LightningModule 的 tensor 都会被自动处理,而是只有从 Dataloader 里获取的 tensor 才会被自动处理,所以对于 transductive learning 的训练,最好自己写Dataloader的处理函数。
- 工程代码参数化
平时我们写模型训练的时候,这部分代码会不断重复,但又不得不做,不如说ealy stopping,精度的调整,显存内存之间的数据转移。这部分代码虽然不难,但减少这部分代码会使得 研究代码 更加清晰,整体也更加简洁。
下面是简单的展示,表示使用 LightningModule 建立好模型后,如何进行训练。
model = LightningModuleClass()
trainer = pl.Trainer(gpus="0", # 用来配置使用什么GPU
precision=32, # 用来配置使用什么精度,默认是32
max_epochs=200 # 迭代次数
)
trainer.fit(model) # 开始训练
trainer.test() # 训练完之后测试
结语
以上就是我对于 PyTorch Lightning 的入门总结,自己在这里也走了很多坑,也把官方文档过了一遍,但我的目的不是仿照官方文档翻译一遍,而是希望有自己的实践体会和相对于官方文档的规范更直观。