文章目录
- 网络结构一致
- 网络定义及读入预训练模型
- 网络结构的匹配与模型参数载入
- 网络结构不一致
- Reference
网络结构一致
这里模型使用Pytorch提供的ResNet50作为backbone,预训练模型的backbone也为ResNet50。
网络定义及读入预训练模型
利用pytorch提供的网络定义模型的backbone,设置不读取pytorch官方提供的ImagNet预训练模型,然后根据自己的任务设置backbone最后一层的参数。
import torch
import torchvision.models
from torch import nn
from collections import OrderedDict
model = torchvision.models.resnet50(pretrained=False) # Pytorch提供的网络结构,不加载官方预训练模型(ImageNet)
fc_features = model.fc.in_features # 提取fc层中固定的参数
model.fc = nn.Linear(fc_features, 400) # 修改为自己项目的类别数量(也即预训练模型的类别数)
读入自己任务需要的预训练模型(自己提前下载)。
# 读入自己需要的预训练模型
pthfile = 'tf_model_zoo/tsn2d_kinetics400_rgb_r50_seg3_f1s1-b702e12f.pth' # ResNet50,Kinetics400
pretrained_model = torch.load(pthfile)
网络结构的匹配与模型参数载入
通过输出网络模型或者调试查看网络属性,观察到所定义的模型与预训练模型的网络层名不一致,比如第一层名字分别为:conv1.weight
和backbone.conv1.weight
,那么使预训练模型的网络层名与定义模型的网络层名一致只需要去掉前缀就行了。
其中,如果需要全连接层的参数,使其也保持一致就可以了,同时也有两种载入预训练模型参数的方式,代码如下:
# 更改预训练模型的层名,使其匹配pytorch的ResNet50网络层名
new_state_dict = OrderedDict()
for k, v in pretrained_model['state_dict'].items():
name = k[9:] # remove `backbone.`
if name == 'fc_cls.weight': # 全连接层的参数也匹配
name = 'fc.weight'
new_state_dict[name] = v
if name == 'fc_cls.bias':
name = 'fc.bias'
new_state_dict[name] = v
new_state_dict[name] = v
model_dict = model.state_dict() # 查看ResNet50 backbone的初始参数
model.load_state_dict(new_state_dict, strict=True) # 载入预训练模型参数,严格匹配key的名字和数量
# 第二种预训练模型载入方式
# model_dict.update(new_state_dict) # 更新ResNet50 backbone的初始参数
# model.load_state_dict(model_dict) # 载入更新后的参数
print(model.state_dict())
网络结构不一致
这里模型使用Pytorch提供的ResNet101作为backbone,而预训练模型的backbone为ResNet50。
import torch
import torchvision.models
from torch import nn
from collections import OrderedDict
model = torchvision.models.resnet101(pretrained=False) # Pytorch提供的网络结构,不加载官方预训练模型(ImageNet)
fc_features = model.fc.in_features # 提取fc层中固定的参数
model.fc = nn.Linear(fc_features, 51) # 修改为自己项目的类别数量(也即预训练模型的类别数)
与之前同理,首先观察网络层名有什么差异,然后更改层名使其一致,这里由于网络结构不同,所以只将ResNet101与ResNet50中都有的层名用预训练权重参数赋值。代码如下:
# 更改预训练模型的层名,使其匹配models.py中定义的base_model网络层名
new_state_dict = OrderedDict()
for k, v in pretrained_model['state_dict'].items():
name = 'module.base_model.' + k[9:] # 更改conv1.weight等层名的前缀
new_state_dict[name] = v
pretrained_dict = {k: v for k, v in new_state_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict = model.state_dict()
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
Reference
https://zhuanlan.zhihu.com/p/84797438