一、加载与Model中参数不一致的预训练模型

我们在构造好了一个模型后,可能要加载一些训练好的模型参数。举例子如下:

假设 trained.pth 是一个训练好的网络的模型参数存储

model = Net()是我们刚刚生成的一个新模型,我们希望model将trained.pth中的参数加载加载进来,但是model中多了一些trained.pth中不存在的参数,如果使用下面的命令:

state_dict = torch.load('trained.pth')
model.load_state_dict(state_dict)

会报错,说key对应不上,因为model你强人所难,我堂堂trained.pth没有你的那些个零碎玩意,你非要向我索取,我上哪给你弄去。但是model不干,说既然你不能完全满足我的需要,那么你有什么我就拿什么吧,怎么办呢?下面的指令代码就行了:

model.load_state_dict(state_dict, strict=False)

二、复制训练好的模型参数

net_path = 'PatAdaAttn-epoch20.pth'
    checkpoint = torch.load(net_path)
    state_dict = {}
    for k, v in checkpoint['model'].items():
        if 'smoother' not in k:
            state_dict.update({k: v})

此时state_dict已经复制了PatAdaAttn-epoch20.pth中的'model'的参数。

三、保存加载自定义模型

上面保存加载的 ‘PatAdaAttn-epoch20.pth’ 其实一个字典,通常包含如下内容:

1)网络结构:输入尺寸、输出尺寸以及隐藏层信息,以便能够在加载时重建模型。
2)模型的权重参数:包含各网络层训练后的可学习参数,可以在模型实例上调用 state_dict() 方法来获取,比如前面介绍只保存模型权重参数时用到的 model.state_dict()。
3)优化器参数:有时保存模型的参数需要稍后接着训练,那么就必须保存优化器的状态和所其使用的超参数,也是在优化器实例上调用 state_dict() 方法来获取这些参数。
4)其他信息:有时我们需要保存一些其他的信息,比如 epoch,batch_size 等超参数。

知道了这些,那么我们就可以自定义需要保存的内容,比如:

1 # saving a checkpoint assuming the network class named ClassNet
2 checkpoint = {'model': ClassNet(),
3               'model_state_dict': model.state_dict(),
4               'optimizer_state_dict': optimizer.state_dict(),
5               'epoch': epoch}
6 
7 torch.save(checkpoint, 'checkpoint.pkl')

上面的 checkpoint 是个字典,里面有4个键值对,分别表示网络模型的不同信息。

然后我们要加载上面保存的自定义的模型:

1 def load_checkpoint(filepath):
 2     checkpoint = torch.load(filepath)
 3     model = checkpoint['model']  # 提取网络结构
 4     model.load_state_dict(checkpoint['model_state_dict'])  # 加载网络权重参数
 5     optimizer = TheOptimizerClass()
 6     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # 加载优化器参数
 7     
 8     for parameter in model.parameters():
 9         parameter.requires_grad = False
10     model.eval()
11     
12     return model
13     
14 model = load_checkpoint('checkpoint.pkl')

如果加载模型只是为了进行推理测试,则将每一层的 requires_grad 置为 False,即固定这些权重参数;还需要调用 model.eval() 将模型置为测试模式,主要是将 dropoutbatch normalization 层进行固定,否则模型的预测结果每次都会不同。

如果希望继续训练,则调用 model.train(),以确保网络模型处于训练模式。

state_dict() 也是一个Python字典对象,model.state_dict() 将每一层的可学习参数映射为参数矩阵,其中只包含具有可学习参数的层(卷积层、全连接层等)。

比如下面这个例子:

1 # Define model
 2 class TheModelClass(nn.Module):
 3     def __init__(self):
 4         super(TheModelClass, self).__init__()
 5         self.conv1 = nn.Conv2d(3, 8, 5)
 6         self.bn = nn.BatchNorm2d(8)
 7         self.conv2 = nn.Conv2d(8, 16, 5)
 8         self.pool = nn.MaxPool2d(2, 2)
 9         self.fc1 = nn.Linear(16 * 5 * 5, 120)
10         self.fc2 = nn.Linear(120, 10)
11 
12     def forward(self, x):
13         x = self.pool(F.relu(self.conv1(x)))
14         x = self.bn(x)
15         x = self.pool(F.relu(self.conv2(x)))
16         x = x.view(-1, 16 * 5 * 5)
17         x = F.relu(self.fc1(x))
18         x = F.relu(self.fc2(x))
19         x = self.fc3(x)
20         return x
21     
22     # Initialize model
23     model = TheModelClass()
24 
25     # Initialize optimizer
26     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
27 
28     print("Model's state_dict:")
29     for param_tensor in model.state_dict():
30         print(param_tensor, "\t", model.state_dict()[param_tensor].size())
31 
32     print("Optimizer's state_dict:")
33     for var_name in optimizer.state_dict():
34         print(var_name, "\t", optimizer.state_dict()[var_name])

输出为:

Model's state_dict:
conv1.weight            torch.Size([8, 3, 5, 5])
conv1.bias              torch.Size([8])
bn.weight               torch.Size([8])
bn.bias                 torch.Size([8])
bn.running_mean         torch.Size([8])
bn.running_var          torch.Size([8])
bn.num_batches_tracked  torch.Size([])
conv2.weight            torch.Size([16, 8, 5, 5])
conv2.bias              torch.Size([16])
fc1.weight              torch.Size([120, 400])
fc1.bias                torch.Size([120])
fc2.weight              torch.Size([10, 120])
fc2.bias                torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0,.....]

可以看到 model.state_dict() 保存了卷积层,BatchNorm层和最大池化层的信息;而 optimizer.state_dict() 则保存的优化器的状态和相关的超参数。