
pytorch-lightning is a high-level model interface built on top of PyTorch. pytorch-lightning is to PyTorch what Keras is to TensorFlow. Its most notable features are:

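  • There is no need to write a custom training loop; you only specify how the loss is computed.
  • Callbacks make it easy to add checkpoint saving, early stopping, and similar features.
  • Models can be trained on a single CPU, multiple CPUs, a single GPU, multiple GPUs, or even multiple TPUs with very little effort.
  • The torchmetrics library makes it easy to add common evaluation metrics such as Accuracy, AUC, and Precision.
  • Techniques such as multi-batch gradient accumulation, half-precision and mixed-precision training, and automatic search for the largest batch_size are easy to apply and speed up training.
  • Techniques such as SWA (stochastic weight averaging), CyclicLR (a cyclical learning-rate schedule), and auto_lr_find (learning-rate finder) are easy to apply and can improve model accuracy.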

pytorch-lightning is usually installed and imported as follows.

# Install
pip install pytorch-lightning

# Import
import pytorch_lightning as pl

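As the abbreviation suggests, it helps us do deep learning research beautifully ("piaoliang" in Chinese, abbreviated pl). 😋😋 You do the research. Lightning will do everything else. ⭐️⭐️

Reference documentation: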

  • pl_docs: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html
  • pl_template: https://github.com/PyTorchLightning/deep-learning-project-template
  • torchmetrics: https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html

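I. Design philosophy of pytorch-lightning

The core design philosophy of pytorch-lightning is to separate the research code (defining the model) from the engineering code (training the model) in a deep learning project. Users only need to focus on implementing the research code (pl.LightningModule), while the engineering code is handled uniformly by the training utility class (pl.Trainer). In more detail, the code of a deep learning project can be divided into the following 4 parts:

  • Research code: implemented by the user by subclassing LightningModule.
  • Engineering code: handled by calling Trainer; the user does not need to worry about it.
  • Non-essential code (non-essential research code, logging, etc.): implemented by the user through Callbacks.
  • Data: implemented by the user with torch.utils.data.DataLoader, and optionally wrapped in a pl.LightningDataModule.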
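
To make this split concrete, here is a toy, pure-Python analogy. It is not Lightning code: ToyModule and ToyTrainer are hypothetical names used only for illustration. The module holds the research part (one parameter and its loss), and the trainer holds a generic loop that would work for any object exposing a training_step.

```python
# Toy analogy of the LightningModule / Trainer split.
# ToyModule and ToyTrainer are hypothetical; they are not part of pytorch-lightning.

class ToyModule:
    """Research code: a 1-D linear model y = w * x and its squared-error loss."""

    def __init__(self):
        self.w = 0.0

    def training_step(self, batch):
        x, y = batch
        error = self.w * x - y
        loss = error ** 2
        grad = 2 * error * x  # d(loss)/d(w), derived by hand
        return loss, grad


class ToyTrainer:
    """Engineering code: a generic loop that knows nothing about the model."""

    def __init__(self, max_epochs=50, lr=0.05):
        self.max_epochs = max_epochs
        self.lr = lr

    def fit(self, module, data):
        for _ in range(self.max_epochs):
            for batch in data:
                _, grad = module.training_step(batch)
                module.w -= self.lr * grad  # plain SGD update
        return module


data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]  # samples of y = 3 * x
model = ToyTrainer().fit(ToyModule(), data)
print(round(model.w, 3))
```

In real Lightning code these two roles are played by pl.LightningModule and pl.Trainer.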

II. A pytorch-lightning usage example

Below we use the MNIST image classification problem as an example to demonstrate pytorch-lightning best practices.

1. Prepare the data

import torch   
from torch import nn   
from torchvision import transforms as T  
from torchvision.datasets import MNIST  
from torch.utils.data import DataLoader,random_split  
import pytorch_lightning as pl   
from torchmetrics import Accuracy
class MNISTDataModule(pl.LightningDataModule):  
    def __init__(self, data_dir: str = "./minist/",   
                 batch_size: int = 32,  
                 num_workers: int =4):  
        super().__init__()  
        self.data_dir = data_dir  
        self.batch_size = batch_size  
        self.num_workers = num_workers  
  
    def setup(self, stage = None):  
        transform = T.Compose([T.ToTensor()])  
        self.ds_test = MNIST(self.data_dir, train=False,transform=transform,download=True)  
        self.ds_predict = MNIST(self.data_dir, train=False,transform=transform,download=True)  
        ds_full = MNIST(self.data_dir, train=True,transform=transform,download=True)  
        self.ds_train, self.ds_val = random_split(ds_full, [55000, 5000])  
  
    def train_dataloader(self):  
        return DataLoader(self.ds_train, batch_size=self.batch_size,  
                          shuffle=True, num_workers=self.num_workers,  
                          pin_memory=True)  
  
    def val_dataloader(self):  
        return DataLoader(self.ds_val, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=True)  
  
    def test_dataloader(self):  
        return DataLoader(self.ds_test, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=True)  
  
    def predict_dataloader(self):  
        return DataLoader(self.ds_predict, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=True)
data_mnist = MNISTDataModule()  
data_mnist.setup()
for features,labels in data_mnist.train_dataloader():  
    print(features.shape)  
    print(labels.shape)  
    break
torch.Size([32, 1, 28, 28])  
torch.Size([32])
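
As a quick sanity check on the data module, the number of batches per epoch follows directly from the split sizes and batch_size. The sketch below assumes the DataLoader default drop_last=False and the standard 10,000-image MNIST test set; num_batches is a hypothetical helper, not a PyTorch function.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

train_batches = num_batches(55000, 32)  # ds_train from random_split
val_batches = num_batches(5000, 32)     # ds_val from random_split
test_batches = num_batches(10000, 32)   # MNIST test set

print(train_batches, val_batches, test_batches)  # 1719 157 313
```

The 1719 training batches plus 157 validation batches add up to 1876, which is the total shown on the training progress bar later in this example, and 313 is the length of the test progress bars.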

2. Define the model

net = nn.Sequential(  
    nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Dropout2d(p = 0.1),  
    nn.AdaptiveMaxPool2d((1,1)),  
    nn.Flatten(),  
    nn.Linear(64,32),  
    nn.ReLU(),  
    nn.Linear(32,10)  
)  
  
class Model(pl.LightningModule):  
      
    def __init__(self,net,learning_rate=1e-3):  
        super().__init__()  
        self.save_hyperparameters()  
        self.net = net  
        self.train_acc = Accuracy()  
        self.val_acc = Accuracy()  
        self.test_acc = Accuracy()   
          
          
    def forward(self,x):  
        x = self.net(x)  
        return x  
      
      
    # Define the loss
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    # Define the metrics
    def training_step_end(self,outputs):  
        train_acc = self.train_acc(outputs['preds'], outputs['y']).item()      
        self.log("train_acc",train_acc,prog_bar=True)  
        return {"loss":outputs["loss"].mean()}  
      
    # Define the optimizer and, optionally, an lr_scheduler
    def configure_optimizers(self):  
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)  
      
    def validation_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
  
    def validation_step_end(self,outputs):  
        val_acc = self.val_acc(outputs['preds'], outputs['y']).item()      
        self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
        self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)  
      
    def test_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    def test_step_end(self,outputs):  
        test_acc = self.test_acc(outputs['preds'], outputs['y']).item()      
        self.log("test_acc",test_acc,on_epoch=True,on_step=False)  
        self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
      
model = Model(net)  
  
# Check the model size
model_size = pl.utilities.memory.get_model_size_mb(model)  
print("model_size = {} M \n".format(model_size))  
model.example_input_array = [features]  
summary = pl.utilities.model_summary.ModelSummary(model,max_depth=-1)  
print(summary)
model_size = 0.218447 M   
  
   | Name      | Type              | Params | In sizes         | Out sizes         
---------------------------------------------------------------------------------------  
0  | net       | Sequential        | 54.0 K | [32, 1, 28, 28]  | [32, 10]          
1  | net.0     | Conv2d            | 320    | [32, 1, 28, 28]  | [32, 32, 26, 26]  
2  | net.1     | MaxPool2d         | 0      | [32, 32, 26, 26] | [32, 32, 13, 13]  
3  | net.2     | Conv2d            | 51.3 K | [32, 32, 13, 13] | [32, 64, 9, 9]    
4  | net.3     | MaxPool2d         | 0      | [32, 64, 9, 9]   | [32, 64, 4, 4]    
5  | net.4     | Dropout2d         | 0      | [32, 64, 4, 4]   | [32, 64, 4, 4]    
6  | net.5     | AdaptiveMaxPool2d | 0      | [32, 64, 4, 4]   | [32, 64, 1, 1]    
7  | net.6     | Flatten           | 0      | [32, 64, 1, 1]   | [32, 64]          
8  | net.7     | Linear            | 2.1 K  | [32, 64]         | [32, 32]          
9  | net.8     | ReLU              | 0      | [32, 32]         | [32, 32]          
10 | net.9     | Linear            | 330    | [32, 32]         | [32, 10]          
11 | train_acc | Accuracy          | 0      | ?                | ?                 
12 | val_acc   | Accuracy          | 0      | ?                | ?                 
13 | test_acc  | Accuracy          | 0      | ?                | ?                 
---------------------------------------------------------------------------------------  
54.0 K    Trainable params  
0         Non-trainable params  
54.0 K    Total params  
0.216     Total estimated model params size (MB)
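
The numbers in this summary can be reproduced by hand. The sketch below recomputes the parameter counts and the spatial sizes from the layer definitions. The helper functions are hypothetical, and the size estimate assumes float32 parameters (4 bytes each).

```python
def conv2d_params(c_in, c_out, k):
    """Weights (c_in * c_out * k * k) plus one bias per output channel."""
    return c_in * c_out * k * k + c_out

def linear_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

def out_size(size, k, stride=1, padding=0):
    """Output height/width of a conv or pooling layer."""
    return (size + 2 * padding - k) // stride + 1

params = {
    "net.0 Conv2d(1, 32, 3)": conv2d_params(1, 32, 3),
    "net.2 Conv2d(32, 64, 5)": conv2d_params(32, 64, 5),
    "net.7 Linear(64, 32)": linear_params(64, 32),
    "net.9 Linear(32, 10)": linear_params(32, 10),
}
total_params = sum(params.values())
size_mb = total_params * 4 / 1e6  # float32 = 4 bytes per parameter

s0 = out_size(28, 3)            # after net.0
s1 = out_size(s0, 2, stride=2)  # after net.1
s2 = out_size(s1, 5)            # after net.2
s3 = out_size(s2, 2, stride=2)  # after net.3

print(params)
print(total_params, round(size_mb, 3))  # 53994 0.216
print(s0, s1, s2, s3)                   # 26 13 9 4
```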

3. Train the model

pl.seed_everything(1234)  
  
ckpt_callback = pl.callbacks.ModelCheckpoint(  
    monitor='val_loss',  
    save_top_k=1,  
    mode='min'  
)  
early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss',  
               patience=3,  
               mode = 'min')  
  
# gpus=0 trains on the CPU, gpus=1 on one GPU, gpus=2 on two GPUs, gpus=-1 on all available GPUs;
# gpus=[0,1] trains on GPUs 0 and 1, gpus="0,1,2,3" trains on GPUs 0, 1, 2, and 3
# tpu_cores=1 trains on one TPU core

trainer = pl.Trainer(max_epochs=20,
     #gpus=0, # single-CPU mode
     gpus=1, # single-GPU mode
     #num_processes=4,strategy="ddp_find_unused_parameters_false", # multi-CPU (multi-process) mode
     #gpus=[0,1,2,3],strategy="dp", # multi-GPU DataParallel (modest speedup)
     #gpus=[0,1,2,3],strategy="ddp_find_unused_parameters_false", # multi-GPU DistributedDataParallel (good speedup)
     callbacks = [ckpt_callback,early_stopping],
     profiler="simple")

# Resume training from a checkpoint
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

# Train the model
trainer.fit(model,data_mnist)
Epoch 8: 100%  
1876/1876 [01:44<00:00, 17.93it/s, loss=0.0603, v_num=0, train_acc=1.000, val_acc=0.985]
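
Training can end before max_epochs because of the EarlyStopping callback (monitor='val_loss', patience=3, mode='min'). The following is a simplified sketch of the patience rule, not Lightning's implementation; early_stop_epoch is a hypothetical helper and the loss values are made up for illustration.

```python
def early_stop_epoch(values, patience=3, mode="min"):
    """Return the epoch index at which training would stop, or None.

    Training stops once the monitored value has failed to improve
    for `patience` consecutive checks.
    """
    best = None
    wait = 0
    for epoch, value in enumerate(values):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best
        else:
            improved = value > best
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation losses: the best value is at epoch 1,
# followed by three epochs without improvement.
val_losses = [0.50, 0.40, 0.41, 0.42, 0.43, 0.30]
stop_epoch = early_stop_epoch(val_losses, patience=3, mode="min")
print(stop_epoch)  # 4
```

Note that the last value (0.30) is never reached: with patience=3 the run stops at epoch 4, so a larger patience trades training time for a chance to escape a plateau.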

4. Evaluate the model

result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path='best')
--------------------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9966545701026917, 'test_loss': 0.010617421939969063}  
--------------------------------------------------------------------------------
result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path='best')
--------------------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9865999817848206, 'test_loss': 0.042671505361795425}  
--------------------------------------------------------------------------------
result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path='best')
--------------------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.987500011920929, 'test_loss': 0.047178059816360474}  
--------------------------------------------------------------------------------

5. Use the model

data,label = next(iter(data_mnist.test_dataloader()))
model.eval()  
prediction = model(data)  
print(prediction)
tensor([[-13.0112,  -2.8257,  -1.8588,  -3.6137,  -0.3307,  -5.4953, -19.7282,  
          15.9651,  -8.0379,  -2.2925],  
        [ -6.0261,  -2.5480,  13.4140,  -5.5701, -10.2049,  -6.4469,  -3.7119,  
          -6.0732,  -6.0826,  -7.7339],  
          ...  
        [-16.7028,  -4.9060,   0.4400,  24.4337, -12.8793,   1.5085, -17.9232,  
          -3.0839,   0.5491,   1.9846],  
        [ -5.0909,  10.1805,  -8.2528,  -9.2240,  -1.8044,  -4.0296,  -8.2297,  
          -3.1828,  -5.9361,  -4.8410]], grad_fn=<AddmmBackward0>)
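
The tensor above contains raw logits, one row per image and one column per digit class. To turn them into class predictions, take the argmax of each row; a softmax gives probabilities if needed. The sketch below does this in plain Python on the first two rows printed above. With tensors, the same result would come from prediction.argmax(dim=1) and torch.softmax(prediction, dim=1).

```python
import math

# First two rows of the printed prediction tensor.
logits = [
    [-13.0112, -2.8257, -1.8588, -3.6137, -0.3307,
     -5.4953, -19.7282, 15.9651, -8.0379, -2.2925],
    [-6.0261, -2.5480, 13.4140, -5.5701, -10.2049,
     -6.4469, -3.7119, -6.0732, -6.0826, -7.7339],
]

def softmax(row):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(row):
    """Index of the largest entry."""
    return max(range(len(row)), key=row.__getitem__)

predicted = [argmax(row) for row in logits]
probs = [softmax(row) for row in logits]
print(predicted)  # [7, 2]
```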

6. Save the model

By default, the path of the best checkpoint is stored in trainer.checkpoint_callback.best_model_path, and the model can be loaded directly from it.

print(trainer.checkpoint_callback.best_model_path)  
print(trainer.checkpoint_callback.best_model_score)
lightning_logs/version_10/checkpoints/epoch=8-step=15470.ckpt  
tensor(0.0376, device='cuda:0')
model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  
trainer_clone = pl.Trainer(max_epochs=3,gpus=1)   
result = trainer_clone.test(model_clone,data_mnist.test_dataloader())
print(result)
--------------------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9887999892234802, 'test_loss': 0.03627564385533333}  
--------------------------------------------------------------------------------  
[{'test_acc': 0.9887999892234802, 'test_loss': 0.03627564385533333}]

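III. Training acceleration tips

This part covers several tips for speeding up model training with pytorch_lightning.

  • 1. Load data with multiple processes (num_workers=4)
  • 2. Use pinned (page-locked) memory (pin_memory=True)
  • 3. Use accelerators (gpus=4,strategy="ddp_find_unused_parameters_false")
  • 4. Use gradient accumulation (accumulate_grad_batches=6)
  • 5. Use half precision (precision=16,batch_size=2*batch_size)
  • 6. Automatically search for the largest batch_size (auto_scale_batch_size='binsearch')

(Note: an excessively large batch_size can hurt model learning.) For details, see: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html

We wrap the training code into the following script so that it is easy to reuse in the tests below.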

%%writefile mnist_cnn.py  
import torch   
from torch import nn   
from argparse import ArgumentParser  
  
import torchvision   
from torchvision import transforms as T  
from torchvision.datasets import MNIST  
from torch.utils.data import DataLoader,random_split  
import pytorch_lightning as pl  
from torchmetrics import Accuracy  
  
#================================================================================  
# I. Prepare the data
#================================================================================  
  
class MNISTDataModule(pl.LightningDataModule):  
    def __init__(self, data_dir: str = "./minist/",   
                 batch_size: int = 32,  
                 num_workers: int =4,  
                 pin_memory:bool =True):  
        super().__init__()  
        self.data_dir = data_dir  
        self.batch_size = batch_size  
        self.num_workers = num_workers  
        self.pin_memory = pin_memory  
  
    def setup(self, stage = None):  
        transform = T.Compose([T.ToTensor()])  
        self.ds_test = MNIST(self.data_dir, download=True,train=False,transform=transform)  
        self.ds_predict = MNIST(self.data_dir, download=True, train=False,transform=transform)  
        ds_full = MNIST(self.data_dir, download=True, train=True,transform=transform)  
        self.ds_train, self.ds_val = random_split(ds_full, [55000, 5000])  
  
    def train_dataloader(self):  
        return DataLoader(self.ds_train, batch_size=self.batch_size,  
                          shuffle=True, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def val_dataloader(self):  
        return DataLoader(self.ds_val, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def test_dataloader(self):  
        return DataLoader(self.ds_test, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def predict_dataloader(self):  
        return DataLoader(self.ds_predict, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
      
    @staticmethod  
    def add_dataset_args(parent_parser):  
        parser = ArgumentParser(parents=[parent_parser], add_help=False)  
        parser.add_argument('--batch_size', type=int, default=32)  
        parser.add_argument('--num_workers', type=int, default=4)  
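        # type=bool would parse the string "False" as True, so convert explicitly
        parser.add_argument('--pin_memory', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True)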
        return parser  
  
#================================================================================  
# II. Define the model
#================================================================================  
  
net = nn.Sequential(  
    nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Dropout2d(p = 0.1),  
    nn.AdaptiveMaxPool2d((1,1)),  
    nn.Flatten(),  
    nn.Linear(64,32),  
    nn.ReLU(),  
    nn.Linear(32,10)  
)  
  
class Model(pl.LightningModule):  
      
    def __init__(self,net,learning_rate=1e-3):  
        super().__init__()  
        self.save_hyperparameters()  
        self.net = net  
        self.train_acc = Accuracy()  
        self.val_acc = Accuracy()  
        self.test_acc = Accuracy()   
          
          
    def forward(self,x):  
        x = self.net(x)  
        return x  
      
      
    # Define the loss
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    # Define the metrics
    def training_step_end(self,outputs):  
        train_acc = self.train_acc(outputs['preds'], outputs['y']).item()      
        self.log("train_acc",train_acc,prog_bar=True)  
        return {"loss":outputs["loss"].mean()}  
      
    # Define the optimizer and, optionally, an lr_scheduler
    def configure_optimizers(self):  
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)  
      
    def validation_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
  
    def validation_step_end(self,outputs):  
        val_acc = self.val_acc(outputs['preds'], outputs['y']).item()      
        self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
        self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)  
      
    def test_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    def test_step_end(self,outputs):  
        test_acc = self.test_acc(outputs['preds'], outputs['y']).item()      
        self.log("test_acc",test_acc,on_epoch=True,on_step=False)  
        self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
      
    @staticmethod  
    def add_model_args(parent_parser):  
        parser = ArgumentParser(parents=[parent_parser], add_help=False)  
        parser.add_argument('--learning_rate', type=float, default=1e-3)  
        return parser  
      
  
#================================================================================  
# III. Train the model
#================================================================================  
      
def main(hparams):  
    pl.seed_everything(1234)  
      
    data_mnist = MNISTDataModule(batch_size=hparams.batch_size,
                                 num_workers=hparams.num_workers,
                                 pin_memory=hparams.pin_memory)  # forward the --pin_memory flag
      
    model = Model(net,learning_rate=hparams.learning_rate)  
      
    ckpt_callback = pl.callbacks.ModelCheckpoint(  
        monitor='val_loss',  
        save_top_k=1,  
        mode='min'  
    )  
    early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_loss',  
                   patience=3,  
                   mode = 'min')  
      
    trainer = pl.Trainer.from_argparse_args(   
        hparams,  
        max_epochs=10,  
          
        callbacks = [ckpt_callback,early_stopping]  
    )   
  
      
      
    if hparams.auto_scale_batch_size is not None:  
        # Search for the largest batch_size that does not cause an OOM error
        max_batch_size = trainer.tuner.scale_batch_size(model,data_mnist,  
                        mode=hparams.auto_scale_batch_size)  
        data_mnist.batch_size = max_batch_size  
          
        # Equivalent to
        #trainer.tune(model,data_mnist)  
          
      
    #gpus=0, # single-CPU mode
    #gpus=1, # single-GPU mode
    #num_processes=4,strategy="ddp_find_unused_parameters_false", # multi-CPU (multi-process) mode
    #gpus=4,strategy="dp", # multi-GPU (dp gives a modest speedup)
    #gpus=4,strategy="ddp_find_unused_parameters_false" # multi-GPU (ddp gives a good speedup)
  
    trainer.fit(model,data_mnist)  
    result = trainer.test(model,data_mnist,ckpt_path='best')  
  
if __name__ == "__main__":  
    parser = ArgumentParser()  
    parser = MNISTDataModule.add_dataset_args(parser)  
    parser = Model.add_model_args(parser)  
    parser = pl.Trainer.add_argparse_args(parser)  
    hparams = parser.parse_args()  
    main(hparams)

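1. Load data with multiple processes (num_workers=4)

Loading data with multiple worker processes keeps data loading from becoming the performance bottleneck.

  • Single-process data loading (num_workers=0, gpus=1): 1min 18s
  • Multi-process data loading (num_workers=4, gpus=1): 59.7s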
%%time  
# Single-process data loading (num_workers=0)
!python3 mnist_cnn.py --num_workers=0 --gpus=1
------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9857000112533569, 'test_loss': 0.04885349050164223}  
--------------------------------------------------------------------------------  
  
CPU times: user 4.67 s, sys: 2.14 s, total: 6.81 s  
Wall time: 2min 50s
%%time  
# Multi-process data loading (num_workers=4)
!python3 mnist_cnn.py --num_workers=4 --gpus=1
---------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9764000177383423, 'test_loss': 0.0820135846734047}  
--------------------------------------------------------------------------------  
Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 163.40it/s]  
CPU times: user 1.56 s, sys: 647 ms, total: 2.21 s  
Wall time: 59.7 s

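2. Use pinned (page-locked) memory (pin_memory=True)

The contents of pinned (page-locked) memory are never swapped out to the host's virtual memory (note: swap space lives on disk), so pinned memory is more efficient to read and write than pageable memory, and copying from it to the GPU is faster as well. When the machine has plenty of RAM, set pin_memory=True. If the system stalls or swap usage grows too high, set pin_memory=False. Because the benefit of pin_memory depends on the hardware, and the PyTorch developers cannot assume that every practitioner owns a high-end machine, pin_memory defaults to False.

  • Pageable (non-pinned) memory (pin_memory=False, gpus=1): 1min
  • Pinned memory (pin_memory=True, gpus=1): 59.5s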
%%time  
# Pageable (non-pinned) memory (pin_memory=False)
!python3 mnist_cnn.py --pin_memory=False --gpus=1
----------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9812999963760376, 'test_loss': 0.06231774762272835}  
--------------------------------------------------------------------------------  
Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 171.69it/s]  
CPU times: user 1.59 s, sys: 619 ms, total: 2.21 s  
Wall time: 1min
%%time  
# Pinned memory (pin_memory=True)
!python3 mnist_cnn.py --pin_memory=True --gpus=1
---------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9757999777793884, 'test_loss': 0.08017424494028091}  
--------------------------------------------------------------------------------  
Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 174.58it/s]  
CPU times: user 1.54 s, sys: 677 ms, total: 2.22 s  
Wall time: 59.5 s
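
When two runs differ only in a boolean command-line flag, it is worth checking how that flag is parsed. With argparse, type=bool simply calls bool() on the raw string, and any non-empty string, including "False", is truthy. The sketch below shows the pitfall and one common workaround; str2bool is a hypothetical helper, not part of argparse.

```python
from argparse import ArgumentParser

def str2bool(text):
    """Hypothetical helper: accept common spellings of true; all else is False."""
    return str(text).strip().lower() in ("true", "1", "yes", "y")

# Naive declaration: bool("False") is True because the string is non-empty.
naive = ArgumentParser()
naive.add_argument("--pin_memory", type=bool, default=True)

# Explicit conversion gives the intended value.
fixed = ArgumentParser()
fixed.add_argument("--pin_memory", type=str2bool, default=True)

naive_value = naive.parse_args(["--pin_memory=False"]).pin_memory
fixed_value = fixed.parse_args(["--pin_memory=False"]).pin_memory
print(naive_value, fixed_value)  # True False
```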

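3. Use accelerators (gpus=4,strategy="ddp_find_unused_parameters_false")

pl makes it easy to train models on a single CPU, multiple CPUs, a single GPU, multiple GPUs, or even multiple TPUs. The training times for several of these setups are as follows:

  • Single CPU: 2min 17s
  • Single GPU: 59.4 s
  • 4 GPUs (dp mode): 1min
  • 4 GPUs (ddp mode): 38.9 s

In general, for single-machine multi-GPU training we recommend ddp mode, because dp mode has to transfer a large amount of data and model state between devices, which is very time-consuming.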

%%time  
# Single CPU
!python3 mnist_cnn.py --gpus=0
-----------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9790999889373779, 'test_loss': 0.07223792374134064}  
--------------------------------------------------------------------------------  
Testing: 100%|████████████████████████████████| 313/313 [00:05<00:00, 55.95it/s]  
CPU times: user 2.67 s, sys: 740 ms, total: 3.41 s  
Wall time: 2min 17s
%%time  
# Single GPU
!python3 mnist_cnn.py --gpus=1
---------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9778000116348267, 'test_loss': 0.06929327547550201}  
--------------------------------------------------------------------------------  
Testing: 100%|███████████████████████████████| 313/313 [00:01<00:00, 171.04it/s]  
CPU times: user 1.83 s, sys: 488 ms, total: 2.32 s  
Wall time: 1min 3s
%%time  
# Multi-GPU, dp mode (for a fair comparison, batch_size=32*4)
!python3 mnist_cnn.py --gpus=4 --strategy="dp" --batch_size=128
------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9790999889373779, 'test_loss': 0.06855566054582596}  
--------------------------------------------------------------------------------  
Testing: 100%|██████████████████████████████████| 79/79 [00:02<00:00, 38.55it/s]  
CPU times: user 1.2 s, sys: 553 ms, total: 1.75 s  
Wall time: 1min
%%time  
# Multi-GPU, ddp mode
!python3 mnist_cnn.py --gpus=4 --strategy="ddp_find_unused_parameters_false"
---------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9732000231742859, 'test_loss': 0.08606339246034622}  
--------------------------------------------------------------------------------  
Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:00, 85.79it/s]  
CPU times: user 784 ms, sys: 387 ms, total: 1.17 s  
Wall time: 38.9 s
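
The dp run uses batch_size=128 while the ddp run keeps batch_size=32 because the two strategies interpret batch_size differently. In dp mode a single process splits each batch across the GPUs, whereas in ddp mode every process loads its own batch from its own shard of the data. The arithmetic below is a rough sketch under that assumption (each rank receiving about 1/world_size of the samples); the helper names are hypothetical.

```python
import math

def dp_batches(num_samples, batch_size):
    """dp: one process, each batch is split across the GPUs."""
    return math.ceil(num_samples / batch_size)

def ddp_batches(num_samples, batch_size, world_size):
    """ddp: each process iterates over roughly 1/world_size of the data."""
    per_rank = math.ceil(num_samples / world_size)
    return math.ceil(per_rank / batch_size)

n_test = 10000    # MNIST test set size
world_size = 4    # number of GPUs

dp_steps = dp_batches(n_test, 128)
ddp_steps = ddp_batches(n_test, 32, world_size)
ddp_effective_batch = 32 * world_size

print(dp_steps, ddp_steps, ddp_effective_batch)  # 79 79 128
```

Both settings give 79 test batches, matching the 79/79 progress bars above, and both process 128 samples per step across the 4 GPUs.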

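4. Use gradient accumulation (accumulate_grad_batches=6)

Gradient accumulation sums the gradients of several batches and then updates the parameters once with the accumulated gradient, which is equivalent to increasing batch_size. Because a parameter update costs slightly more computation than a plain gradient sum (for most optimizers), gradient accumulation gives a small speedup.

  • 4 GPUs (ddp mode): 38.9 s
  • 4 GPUs (ddp mode) + gradient accumulation: 36.9 s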
%%time  
# Multi-GPU, ddp mode, with gradient accumulation
!python3 mnist_cnn.py --accumulate_grad_batches=6 --gpus=4 --strategy="ddp_find_unused_parameters_false"
----------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9603000283241272, 'test_loss': 0.1400066614151001}  
--------------------------------------------------------------------------------  
Testing: 100%|██████████████████████████████████| 79/79 [00:00<00:00, 89.10it/s]  
CPU times: user 749 ms, sys: 402 ms, total: 1.15 s  
Wall time: 36.9 s
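
The claim that gradient accumulation is equivalent to a larger batch can be checked on a tiny example. The sketch below uses a hand-written 1-D linear model with a mean-squared-error loss and assumes equal-sized micro-batches whose gradients are averaged. It illustrates the idea and is not Lightning's implementation.

```python
def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

# 12 made-up samples, split into 6 micro-batches of 2.
data = [(float(i), 3.0 * i + 1.0) for i in range(1, 13)]
micro_batches = [data[i:i + 2] for i in range(0, len(data), 2)]
k = len(micro_batches)  # plays the role of accumulate_grad_batches

w = 0.5
accumulated = 0.0
for mb in micro_batches:
    accumulated += grad_mse(w, mb) / k  # scale so the sum is an average

full_batch = grad_mse(w, data)
print(accumulated, full_batch)

# Effective batch size of the run above: 32 per GPU, 4 GPUs, 6 accumulated steps.
effective_batch = 32 * 4 * 6
print(effective_batch)  # 768
```

One optimizer step therefore sees 768 samples in the run above, which may explain part of the drop in test accuracy when the epoch budget and learning rate are left unchanged.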

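5. Use half precision (precision=16)

The precision argument selects double (64), float (32), bfloat16 ("bf16"), or half (16) precision training. The default is float (32), the standard precision; bfloat16 ("bf16") is a mixed-precision mode. If you choose half (16) precision and also double batch_size, training usually becomes about 3 times faster.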

%%time   
# Half precision
!python3 mnist_cnn.py --precision=16 --batch_size=64 --gpus=1
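
Half precision trades numeric range for speed and memory. The sketch below uses Python's struct module (format "e" is IEEE 754 half precision) to show that very small values underflow to zero and values above 65504 overflow, which is the usual motivation for running 16-bit training as mixed precision with loss scaling. The scale factor of 1024 is a made-up value for illustration.

```python
import struct

def to_fp16(value):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]

tiny_grad = 1e-8
scale = 1024.0  # hypothetical loss-scale factor

underflowed = to_fp16(tiny_grad)      # too small for fp16
scaled = to_fp16(tiny_grad * scale)   # survives once scaled up

try:
    to_fp16(70000.0)                  # above the fp16 maximum of 65504
    overflowed = False
except OverflowError:
    overflowed = True

print(underflowed, scaled > 0, overflowed)  # 0.0 True True
```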

6. Automatically search for the largest batch_size (auto_scale_batch_size="power")

!python3 mnist_cnn.py --auto_scale_batch_size="power"  --gpus=1
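
The two search modes mentioned in this article, "power" and "binsearch", can be pictured as follows. This is a simplified sketch of the idea, not Lightning's implementation: fits is a hypothetical stand-in for "one training step ran without an out-of-memory error", and the limit of 300 is invented.

```python
def fits(batch_size, limit=300):
    """Hypothetical stand-in for 'a training step succeeded without OOM'."""
    return batch_size <= limit

def scale_power(init=2, max_trials=25):
    """Keep doubling the batch size until it no longer fits."""
    size, best = init, None
    for _ in range(max_trials):
        if not fits(size):
            break
        best = size
        size *= 2
    return best

def scale_binsearch(init=2, max_trials=25):
    """Double until failure, then binary-search between success and failure."""
    low = scale_power(init, max_trials)
    if low is None:
        return None
    high = low * 2
    if fits(high):  # the doubling phase ran out of trials, not memory
        return low
    while high - low > 1:
        mid = (low + high) // 2
        if fits(mid):
            low = mid
        else:
            high = mid
    return low

power_result = scale_power()
binsearch_result = scale_binsearch()
print(power_result, binsearch_result)  # 256 300
```

"power" is cheaper but only returns powers of two times the initial size, while "binsearch" spends a few more trials to get closer to the real limit.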

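IV. Tips for improving model accuracy

pytorch_lightning makes it easy to apply the following accuracy-boosting techniques:

  • SWA (stochastic weight averaging): use pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging.
  • CyclicLR (a cyclical learning-rate schedule): set the lr_scheduler to torch.optim.lr_scheduler.CyclicLR.
  • auto_lr_find (learning-rate finder): set pl.Trainer(auto_lr_find = True).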
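
To give a feel for what a cyclical schedule does, here is a plain-Python sketch of the triangular policy from the cyclical learning rate paper: the learning rate rises linearly from base_lr to max_lr over step_size_up steps and then falls back. triangular_lr is a hypothetical helper written for illustration, and the numeric values are examples only.

```python
import math

def triangular_lr(step, base_lr, max_lr, step_size_up):
    """Triangular cyclical learning rate at a given optimizer step."""
    cycle = math.floor(1 + step / (2 * step_size_up))
    x = abs(step / step_size_up - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

max_lr = 7e-3            # example peak learning rate
base_lr = max_lr / 4.0   # example lower bound
step_size_up = 100       # steps to climb from base_lr to max_lr

schedule = [triangular_lr(s, base_lr, max_lr, step_size_up)
            for s in (0, 50, 100, 150, 200)]
print(schedule)
```

The printed values start at base_lr, reach max_lr at step 100, and return to base_lr at step 200, after which the cycle repeats.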

Reference papers:

  • Cyclical Learning Rates for Training Neural Networks (https://arxiv.org/pdf/1506.01186.pdf)
  • Averaging Weights Leads to Wider Optima and Better Generalization (https://arxiv.org/abs/1803.05407)
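
The core idea of SWA in the second paper is to average the weights visited late in training instead of keeping only the last set. The sketch below shows that averaging step on made-up weight vectors, both in one shot and as a running update. It illustrates the idea only and leaves out everything else a real SWA callback has to handle.

```python
def swa_average(checkpoints):
    """Element-wise mean of a list of weight vectors."""
    n = len(checkpoints)
    return [sum(ws) / n for ws in zip(*checkpoints)]

def running_swa(avg, new, n_averaged):
    """Update a running average that already covers n_averaged models."""
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, new)]

# Hypothetical weights collected at the end of three epochs.
checkpoints = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]

batch_avg = swa_average(checkpoints)

avg = list(checkpoints[0])
for n, weights in enumerate(checkpoints[1:], start=1):
    avg = running_swa(avg, weights, n)

print(batch_avg, avg)  # [3.0, 5.0] [3.0, 5.0]
```

The running form is the practical one, since it needs memory for only one extra copy of the weights.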

We organize the code into the following form for the tests that come next.

%%writefile mnist_cnn.py  
import torch   
from torch import nn   
from argparse import ArgumentParser  
import numpy as np   
  
import torchvision   
from torchvision import transforms as T  
from torchvision.datasets import MNIST  
from torch.utils.data import DataLoader,random_split  
import pytorch_lightning as pl  
from torchmetrics import Accuracy  
  
#================================================================================  
# I. Prepare the data
#================================================================================  
  
class MNISTDataModule(pl.LightningDataModule):  
    def __init__(self, data_dir: str = "./minist/",   
                 batch_size: int = 32,  
                 num_workers: int =4,  
                 pin_memory:bool =True):  
        super().__init__()  
        self.data_dir = data_dir  
        self.batch_size = batch_size  
        self.num_workers = num_workers  
        self.pin_memory = pin_memory  
  
    def setup(self, stage = None):  
        transform = T.Compose([T.ToTensor()])  
        self.ds_test = MNIST(self.data_dir, download=True,train=False,transform=transform)  
        self.ds_predict = MNIST(self.data_dir, download=True, train=False,transform=transform)  
        ds_full = MNIST(self.data_dir, download=True, train=True,transform=transform)  
        ds_train, self.ds_val = random_split(ds_full, [59000, 1000])  
        # To speed up training, randomly keep 3000 samples
        indices = np.arange(59000)  
        np.random.shuffle(indices)  
        self.ds_train = torch.utils.data.dataset.Subset(  
            ds_train,indices = indices[:3000])   
  
    def train_dataloader(self):  
        return DataLoader(self.ds_train, batch_size=self.batch_size,  
                          shuffle=True, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def val_dataloader(self):  
        return DataLoader(self.ds_val, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def test_dataloader(self):  
        return DataLoader(self.ds_test, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
  
    def predict_dataloader(self):  
        return DataLoader(self.ds_predict, batch_size=self.batch_size,  
                          shuffle=False, num_workers=self.num_workers,  
                          pin_memory=self.pin_memory)  
      
    @staticmethod  
    def add_dataset_args(parent_parser):  
        parser = ArgumentParser(parents=[parent_parser], add_help=False)  
        parser.add_argument('--batch_size', type=int, default=32)  
        parser.add_argument('--num_workers', type=int, default=8)  
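        # type=bool would parse the string "False" as True, so convert explicitly
        parser.add_argument('--pin_memory', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True)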
        return parser  
  
#================================================================================  
# II. Define the model
#================================================================================  
  
net = nn.Sequential(  
    nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),  
    nn.MaxPool2d(kernel_size = 2,stride = 2),  
    nn.Dropout2d(p = 0.1),  
    nn.AdaptiveMaxPool2d((1,1)),  
    nn.Flatten(),  
    nn.Linear(64,32),  
    nn.ReLU(),  
    nn.Linear(32,10)  
)  
  
class Model(pl.LightningModule):  
      
    def __init__(self,net,  
                 learning_rate=1e-3,  
                 use_CyclicLR = False,  
                 epoch_size=500):  
        super().__init__()  
        self.save_hyperparameters() # automatically creates self.hparams
        self.net = net  
        self.train_acc = Accuracy()  
        self.val_acc = Accuracy()  
        self.test_acc = Accuracy()   
          
          
    def forward(self,x):  
        x = self.net(x)  
        return x  
      
      
    # Define the loss
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    # Define the metrics
    def training_step_end(self,outputs):  
        train_acc = self.train_acc(outputs['preds'], outputs['y']).item()      
        self.log("train_acc",train_acc,prog_bar=True)  
        return {"loss":outputs["loss"].mean()}  
      
    # Define the optimizer and, optionally, an lr_scheduler
    def configure_optimizers(self):  
        optimizer = torch.optim.RMSprop(self.parameters(), lr=self.hparams.learning_rate)  
        if not self.hparams.use_CyclicLR:  
            return optimizer   
  
        max_lr = self.hparams.learning_rate  
        base_lr = max_lr/4.0  
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,  
            base_lr=base_lr,max_lr=max_lr,  
            step_size_up=5*self.hparams.epoch_size,cycle_momentum=False)  
        self.print("set lr = "+str(max_lr))  
          
        return ([optimizer],[scheduler])  
      
    def validation_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
  
    def validation_step_end(self,outputs):  
        val_acc = self.val_acc(outputs['preds'], outputs['y']).item()      
        self.log("val_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
        self.log("val_acc",val_acc,prog_bar=True,on_epoch=True,on_step=False)  
      
    def test_step(self, batch, batch_idx):  
        x, y = batch  
        preds = self(x)  
        loss = nn.CrossEntropyLoss()(preds,y)  
        return {"loss":loss,"preds":preds.detach(),"y":y.detach()}  
      
    def test_step_end(self,outputs):  
        test_acc = self.test_acc(outputs['preds'], outputs['y']).item()      
        self.log("test_acc",test_acc,on_epoch=True,on_step=False)  
        self.log("test_loss",outputs["loss"].mean(),on_epoch=True,on_step=False)  
      
    @staticmethod  
    def add_model_args(parent_parser):  
        parser = ArgumentParser(parents=[parent_parser], add_help=False)  
        parser.add_argument('--learning_rate', type=float, default=7e-3)  
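        # type=bool would turn any non-empty string (even "False") into True, so parse the text explicitly  
        parser.add_argument('--use_CyclicLR', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False)  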
        return parser  
      
  
#================================================================================  
# 3. Train the model  
#================================================================================  
      
def main(hparams):  
    pl.seed_everything(1234)  
      
    data_mnist = MNISTDataModule(batch_size=hparams.batch_size,  
                                 num_workers=hparams.num_workers)  
    data_mnist.setup()  
    epoch_size = len(data_mnist.ds_train)//data_mnist.batch_size  
      
    model = Model(net,learning_rate=hparams.learning_rate,  
                  use_CyclicLR = hparams.use_CyclicLR,  
                  epoch_size=epoch_size)  
      
    ckpt_callback = pl.callbacks.ModelCheckpoint(  
        monitor='val_acc',  
        save_top_k=3,  
        mode='max'  
    )  
      
    early_stopping = pl.callbacks.EarlyStopping(monitor = 'val_acc',  
                   patience=16,  
                   mode = 'max')  
    callbacks = [ckpt_callback,early_stopping]  
    if hparams.use_swa:  
        callbacks.append(pl.callbacks.StochasticWeightAveraging())  
          
    trainer = pl.Trainer.from_argparse_args(   
        hparams,  
        max_epochs=1000,  
        callbacks = callbacks)   
  
      
    print("hparams.auto_lr_find=",hparams.auto_lr_find)  
    if hparams.auto_lr_find:  
          
        # Search over a range of learning rates  
        lr_finder = trainer.tuner.lr_find(model,  
          datamodule = data_mnist,  
          min_lr=1e-08,  
          max_lr=1,  
          num_training=100,  
          mode='exponential',  
          early_stop_threshold=4.0  
          )  
        lr_finder.plot()   
        lr = lr_finder.suggestion()  
        model.hparams.learning_rate = lr   
        print("suggest lr=",lr)  
          
        del model   
          
        hparams.learning_rate = lr  
        model = Model(net,learning_rate=hparams.learning_rate,  
                  use_CyclicLR = hparams.use_CyclicLR,  
                  epoch_size=epoch_size)  
          
        # Equivalent to:  
        # trainer.tune(model,data_mnist)  
          
  
    trainer.fit(model,data_mnist)  
    train_result = trainer.test(model,data_mnist.train_dataloader(),ckpt_path='best')  
    val_result = trainer.test(model,data_mnist.val_dataloader(),ckpt_path='best')  
    test_result = trainer.test(model,data_mnist.test_dataloader(),ckpt_path='best')  
      
    print("train_result:\n")  
    print(train_result)  
    print("val_result:\n")  
    print(val_result)  
    print("test_result:\n")  
    print(test_result)  
      
      
  
if __name__ == "__main__":  
    parser = ArgumentParser()  
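    # type=bool would turn any non-empty string (even "False") into True, so parse the text explicitly  
    parser.add_argument('--use_swa', default=False, type=lambda s: str(s).lower() in ('true', '1', 'yes'))  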
    parser = MNISTDataModule.add_dataset_args(parser)  
    parser = Model.add_model_args(parser)  
    parser = pl.Trainer.add_argparse_args(parser)  
    hparams = parser.parse_args()  
    main(hparams)
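
One detail worth noting about boolean flags such as `--use_swa` and `--use_CyclicLR`: argparse applies the `type` callable to the raw command-line string, and `bool("False")` is `True` because any non-empty string is truthy. With `type=bool`, passing `--use_swa=False` would therefore still enable SWA. The sketch below shows one way to parse such flags explicitly. The helper name `str2bool` and the accepted spellings are assumptions made for illustration and are not part of the original script.

```python
from argparse import ArgumentParser, ArgumentTypeError


def str2bool(text):
    """Hypothetical helper: map common true/false spellings to a bool."""
    value = str(text).strip().lower()
    if value in ("true", "t", "yes", "y", "1"):
        return True
    if value in ("false", "f", "no", "n", "0"):
        return False
    raise ArgumentTypeError("expected a boolean value, got %r" % text)


def build_parser():
    # Only the two boolean flags from the script are declared here.
    parser = ArgumentParser()
    parser.add_argument("--use_swa", type=str2bool, default=False)
    parser.add_argument("--use_CyclicLR", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--use_swa=False", "--use_CyclicLR=True"])
    print(args.use_swa, args.use_CyclicLR)
```

Raising `ArgumentTypeError` for unknown spellings makes argparse report a usage error instead of silently picking a value.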

1. SWA: stochastic weight averaging (pl.callbacks.stochastic_weight_avg.StochasticWeightAveraging)

  • Baseline training: test_acc = 0.9581000208854675
  • SWA: test_acc = 0.963100016117096
# Baseline training  
!python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false"
------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9581000208854675, 'test_loss': 0.14859822392463684}  
--------------------------------------------------------------------------------
# With SWA  
!python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false" --use_swa=True
-----------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.963100016117096, 'test_loss': 0.18146753311157227}  
--------------------------------------------------------------------------------
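
To make the idea behind SWA concrete: it keeps a running average of the weights visited during the later part of training and evaluates with that average. The sketch below illustrates only this averaging step on plain Python lists. The function name `update_swa` and the toy weight values are made up for illustration, and details that Lightning's `StochasticWeightAveraging` callback takes care of, such as the learning-rate schedule during averaging and batch-norm statistics, are left out.

```python
def update_swa(swa_weights, new_weights, num_averaged):
    """Fold one more weight snapshot into the running average."""
    return [
        (avg * num_averaged + w) / (num_averaged + 1)
        for avg, w in zip(swa_weights, new_weights)
    ]


# Toy snapshots of a two-parameter model at the end of three epochs (made-up values).
snapshots = [[1.0, 4.0], [2.0, 2.0], [3.0, 0.0]]

swa = list(snapshots[0])
for n, weights in enumerate(snapshots[1:], start=1):
    # n snapshots are already contained in the average before this update
    swa = update_swa(swa, weights, num_averaged=n)

print(swa)
```

The result equals the element-wise mean of all snapshots, computed without storing them.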

2. CyclicLR learning-rate schedule (torch.optim.lr_scheduler.CyclicLR)

  • Baseline training: test_acc = 0.9581000208854675
  • SWA: test_acc = 0.963100016117096
  • SWA + CyclicLR schedule: test_acc = 0.9688000082969666
!python3 mnist_cnn.py --gpus=2 --strategy="ddp_find_unused_parameters_false" --use_swa=True --use_CyclicLR=True
------------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9688000082969666, 'test_loss': 0.11470437049865723}  
--------------------------------------------------------------------------------
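
The schedule configured in `configure_optimizers` can be pictured as a triangle wave: the learning rate climbs linearly from `base_lr` to `max_lr` over `step_size_up` scheduler steps and then falls back over the same number of steps. The sketch below reimplements that shape in plain Python as an illustration, assuming CyclicLR's default triangular mode with equal up and down phases. `epoch_size = 1000` is a placeholder, since the real value depends on the dataset split and the batch size.

```python
def triangular_lr(step, base_lr, max_lr, step_size_up):
    """Learning rate at a given scheduler step for a symmetric triangular cycle."""
    cycle_len = 2 * step_size_up
    pos = step % cycle_len                        # position inside the current cycle
    if pos <= step_size_up:
        frac = pos / step_size_up                 # rising edge: 0 -> 1
    else:
        frac = (cycle_len - pos) / step_size_up   # falling edge: 1 -> 0
    return base_lr + (max_lr - base_lr) * frac


max_lr = 7e-3              # default --learning_rate in the script
base_lr = max_lr / 4.0     # same ratio as in configure_optimizers
epoch_size = 1000          # placeholder: steps per epoch (assumption)
step_size_up = 5 * epoch_size

for step in (0, step_size_up // 2, step_size_up, 2 * step_size_up):
    print(step, triangular_lr(step, base_lr, max_lr, step_size_up))
```

With these settings one full cycle spans ten epochs' worth of scheduler steps: five going up and five coming down.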

3. Optimal learning-rate search (auto_lr_find=True)

  • Baseline training: test_acc = 0.9581000208854675
  • SWA: test_acc = 0.963100016117096
  • SWA + CyclicLR schedule: test_acc = 0.9688000082969666
  • SWA + CyclicLR schedule + optimal learning-rate search: test_acc = 0.9693999886512756
!python3 mnist_cnn.py --gpus=1  --auto_lr_find=True --use_swa=True --use_CyclicLR=True
---------------------------------------------------------------  
DATALOADER:0 TEST RESULTS  
{'test_acc': 0.9693999886512756, 'test_loss': 0.11024412512779236}  
--------------------------------------------------------------------------------  
Testing: 100%|███████████████████████████████| 313/313 [00:02<00:00, 137.85it/s]
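
The learning-rate search used in the last run follows the idea of a learning-rate range test: train for a small number of steps while the learning rate grows from `min_lr` to `max_lr`, record the loss at each step, and pick a rate from the region where the loss falls fastest. The sketch below shows only an exponential sweep and a simple steepest-drop pick in plain Python. The function names and the toy loss values are made up, and the exact schedule and suggestion heuristic inside Lightning's `lr_find` may differ.

```python
def exponential_sweep(min_lr, max_lr, num_training):
    """Learning rates spaced geometrically from min_lr to max_lr (inclusive)."""
    ratio = max_lr / min_lr
    return [min_lr * ratio ** (i / (num_training - 1)) for i in range(num_training)]


def steepest_drop_index(losses):
    """Index of the step after which the recorded loss decreases the most."""
    drops = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    return min(range(len(drops)), key=lambda i: drops[i])


# Same range and step count as the lr_find call in the script.
lrs = exponential_sweep(min_lr=1e-8, max_lr=1.0, num_training=100)
print(lrs[0], lrs[-1])

toy_losses = [1.0, 0.9, 0.5, 0.4, 0.8]   # synthetic values, not measured
print(steepest_drop_index(toy_losses))
```

An exponential sweep spends the same number of steps on each order of magnitude, which is why it suits a search range as wide as 1e-8 to 1.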