本文主要讲述以下三个内容:

  • 一、 PyTorch 模型定义的方式
  • 1.1 Sequential
  • 1.2 ModuleList
  • 1.3 ModuleDict
  • 1.4 三种方式的比较
  • 二、PyTorch 修改模型
  • 2.1 修改模型层
  • 2.2 添加外部输入
  • 2.3 添加额外输出
  • 三、PyTorch 模型保存与读取
  • 3.1 模型保存
  • 3.2 模型读取


一、 PyTorch 模型定义的方式

PyTorch 模型定义主要包括两个部分:各个部分的初始化(init)和数据流向定义(forward)。

基于nn.Module(PyTorch 中所有神经网络模块的基类),我们可以通过Sequential, ModuleList 和 ModuleDict 三种方式定义pytorch 模型。

pytorch读取损失_pytorch读取损失

1.1 Sequential

顾名思义,序列(Sequential)这种方式将模型的各个模块像序列一样按顺序串联起来,而模型的前向计算就是将这些模块按添加的顺序逐一计算。

根据层名的不同,排列的时候有两种方式:

  • 直接排列
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(256, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

print(net)
Sequential(
  (0): Linear(in_features=256, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
)
  • 使用 OrderDict
import collections
import torch.nn as nn

net2 = nn.Sequential(collections.OrderedDict([
    ('fc1', nn.Linear(256, 32)),
    ('relu', nn.ReLU()),
    ('fc2', nn.Linear(32, 10))
]))

print(net2)
Sequential(
  (fc1): Linear(in_features=256, out_features=32, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=32, out_features=10, bias=True)
)

可以看到,使用 Sequential 定义模型的好处在于简单、易读,同时使用Sequential 定义的模型不需要再写forward,因为顺序已经定义好了。

但使用 Sequential 也会使得模型定义丧失灵活性,比如需要在模型中间加入一个外部输入时就不适合用Sequential 的方式实现。使用时需根据实际需求加以选择。

1.2 ModuleList

ModuleList 接受一个子模块的列表作为输入,类似于list, 可以进行append 和 extend 操作。

net = nn.ModuleList([nn.Linear(256, 32), nn.ReLU()])
net.append(nn.Linear(32, 10))
print(type(net))
print(net[0])
print(net)
<class 'torch.nn.modules.container.ModuleList'>
Linear(in_features=256, out_features=32, bias=True)
ModuleList(
  (0): Linear(in_features=256, out_features=32, bias=True)
  (1): ReLU()
  (2): Linear(in_features=32, out_features=10, bias=True)
)

需要注意的是,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起。

ModuleList中元素的先后顺序并不代表其在网络中的真实位置顺序,需要经过forward 函数指定各个层的先后顺序后才算完成了模型的定义。具体实现时用for循环即可完成:

class model(nn.Module):
  def __init__(self, in_dim:int, hidden_units:int, out_dim:int):
    super(model, self).__init__()
    self.modulelist = nn.ModuleList([nn.Linear(in_dim, hidden_units), 
    nn.ReLU(),
    nn.Linear(hidden_units, out_dim)
    ])
    
  def forward(self, x):
    for layer in self.modulelist:
      x = layer(x)
    return x
net3 = model(256, 32, 10)
print(net3)
model(
  (modulelist): ModuleList(
    (0): Linear(in_features=256, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=10, bias=True)
  )
)

1.3 ModuleDict

对应模块为nn.ModuleDict()。

ModuleDict 类似于字典(Dict),作用和 ModuleList 类似,只是 ModuleDict 能够更方便地为神经网络的层添加名称

net = nn.ModuleDict({
    'linear': nn.Linear(256, 32),
    'activation': nn.ReLU(),
})
net['output'] = nn.Linear(32, 10)  #添加
print(net['linear'])
print(net.output)
print(net.activation)
print(net)
Linear(in_features=256, out_features=32, bias=True)
Linear(in_features=32, out_features=10, bias=True)
ReLU()
ModuleDict(
  (linear): Linear(in_features=256, out_features=32, bias=True)
  (activation): ReLU()
  (output): Linear(in_features=32, out_features=10, bias=True)
)

1.4 三种方式的比较

Sequential 内的模块按照顺序排列,要保证相邻层的输入输出大小匹配,forward功能已经写好,可以使代码更加简洁,适用于快速验证结果。另一种使用场景是用 nn.Sequential() 写一个卷积块(block),然后像拼积木一样把不同的 block 拼起来组成整个网络,让代码更加简洁。

如果有某个相同的层重复出现,那么 ModuleList 和 ModuleDict 会更加适用,可以“一行顶多行”。

二、PyTorch 修改模型

除了自己构建模型外,我们还可以对现有的模型进行小修改,以满足我们的任务需求。

2.1 修改模型层

以 pytorch 官方视觉库torchvision预定义好的模型 vgg16 为例,来尝试修改模型层。

先来看一下模型的结构:

import torchvision.models as models
net = models.vgg16()
# net = models.resnet50()
print(net)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以看到这个积木模型非常长,而且最后的全连接层的输出维度为1000(out_features = 1000),这是为了适配 ImageNet 预训练的权重。

但在实际任务中,我们可能并不需要高维度的输出,比如做一个0到9的手写数字识别任务只需要输出10个维度,故我们尝试修改这个模型。

from collections import OrderedDict
import torchvision.models as models
net = models.vgg16()

print('修改前的 net.classifier: ')
print(net.classifier)

classifier2 = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(25088, 4096)),
    ('relu1', nn.ReLU()),
    ('dropout1', nn.Dropout(0.5)),
    ('fc2', nn.Linear(4096, 512)),
    ('relu2', nn.ReLU()),
    ('dropout2', nn.Dropout(0.5)),
    ('fc3', nn.Linear(512, 10)),
    ('output', nn.Softmax(dim=1))]))

net.classifier = classifier2
print('\n修改后的 net.classifier: ')
print(net.classifier)
修改前的 net.classifier: 
Sequential(
  (0): Linear(in_features=25088, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Dropout(p=0.5, inplace=False)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): ReLU(inplace=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=4096, out_features=1000, bias=True)
)

修改后的 net.classifier: 
Sequential(
  (fc1): Linear(in_features=25088, out_features=4096, bias=True)
  (relu1): ReLU()
  (dropout1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=4096, out_features=512, bias=True)
  (relu2): ReLU()
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc3): Linear(in_features=512, out_features=10, bias=True)
  (output): Softmax(dim=1)
)

2.2 添加外部输入

除了已有的输入之外,我们在训练时有时还需要添加一些额外的补充信息作为输入。

例如图像分类任务中,除了图像外,有时候可能还需要输入一些图像的静态特征变量(如caption)作为补充信息,来尝试提高模型的表现。

添加外部输入的方法为:在 forward 中定义好已有模型的前向传播流程,然后在流程中间增加一个额外的输入变量 add_variable 并修改好其前后的输入输出维度。

这里还是以 torchvision 自带的 vgg16 举个栗子:

import torch

class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        self.net = net  #out_dim = 1000
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc_add = nn.Linear(1001, 10, bias=True)  #1000 + 1 new input variable = 1001
        self.output = nn.Softmax(dim=1)

    def forward(self, x, add_variable):
        x = self.net(x)  #out_dim = 1000
        x = self.relu(x)
        x = self.dropout(x)
        x = torch.cat(x, add_variable.unsqueeze(1), 1)
        x = self.fc_add(x)
        x = self.output(x)

        return x

实现的关键在于 torch.cat 拼接了已有变量和输入的额外变量的tensor。一般情况下,add_variable 是一个单一数值(scalar),此时它的维度为 (batch_size, ),需要通过unsqueeze函数在其第二维补充维数,使其维度与x保持一致,这样才能进行拼接torch.cat操作。

注意添加额外输入变量后全连接层的输入输出维度必须要匹配。

已有模型vgg16的输出 x 为1000维,添加一个额外变量后 x 变为1001维,因此添加额外输入变量后的下一个全连接层fc_add的输入维度为1001。

接下来将net换成已有的模型,就可以用来做10分类任务了。

import torchvision.models as models
net = models.vgg16()

model = Model(net)

注意forward里有两个参数,训练时需要喂给模型两个inputs:

out = model(x, add_variable)

2.3 添加额外输出

在训练过程中,有时可能需要查看中间层变量,或者需要提取中间层变量等,这时就需要添加额外输出了。

添加额外输出比较简单,基本思路就是在forward函数的return里返回多个输出即可。

举个栗子:

import torch

class Model(nn.Module):
    def __init__(self, net):
        super(Model, self).__init__()
        self.net = net  #out_dim = 1000
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc_add = nn.Linear(1001, 10, bias=True)  #1000 + 1 new input variable = 1001
        self.output = nn.Softmax(dim=1)

    def forward(self, x, add_variable):
        x0 = self.net(x)  #out_dim = 1000
        x = self.relu(x0)
        x = self.dropout(x)
        x = torch.cat(x, add_variable.unsqueeze(1), 1)
        x = self.fc_add(x)
        x = self.output(x)

        return x, x0

之后按照之前的步骤实例化模型:

import torchvision.models as models
net = models.vgg16()

model = Model(net)

在训练时注意模型有两个outputs:

out10, out1000 = model(x, add_variable)

三、PyTorch 模型保存与读取

对于一个复杂的模型,我们希望训练好以后都可以直接拿来用,而不用重复训练,浪费太多时间,这个时候就需要用到模型保存与读取的知识了。

3.1 模型保存

PyTorch 存储模型主要有3种格式,分别为 pkl、pt、pth。他们仨都是二进制文件,都支持保存模型结构和权重,使用起来基本没有区别。

PyTorch 模型主要包括两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key是层名,value是权重向量)。模型存储也由此分为两种类型:

  • 存储整个模型(包括结构和权重);
  • 只存储模型权重。
import torchvision.models as models
net = models.resnet34(pretrained=True)

savedir = './model_pretrained.pkl'
savedir2 = './model_pretrained2.pkl'

# 保存整个模型
torch.save(net, savedir)

# 保存模型权重
torch.save(net.state_dict, savedir2)
print('Models are saved. ')
Models are saved.

3.2 模型读取

有保存就有读取,PyTorch 可以用torch.load函数读取加载模型。

model1 = torch.load(savedir)
print(model1)

model2 = torch.load(savedir2)
print(model2)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  ...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

<bound method Module.state_dict of ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  ...
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)>

更多关于单卡和多卡模型存储的知识见参考资料[1]。

参考资料:

[1] Datawhale_深入浅出pytorch

[2] https://zhuanlan.zhihu.com/p/64990232