系列文章目录

tensor运算小结



文章目录

  • 系列文章目录
  • 前言
  • 方法一:模型参数
  • 1. 模型参数存储
  • 2. 模型参数加载
  • 方法二:模型本身
  • 1. 模型存储
  • 2. 读入模型
  • 3. 注意事项
  • 总结



前言

在多人合作、模型训练耗时、模型需要部署并运用于生产等情景下,需要将模型结果存储固定并重新加载,出于快速、前后结果的一致性等方面的考虑。
那如何进行模型存储并重新使用呢?本文通过以下两种方法实现PyTorch框架下模型在本地环境的存储和加载重用。


方法一:模型参数

1. 模型参数存储

假设已经有了训练好的模型,此处用 trained_model 代替。
方法一使用pkl格式的文件对参数进行存储。也可以是用pt、pth格式进行存储。

import torch
# trained_model 此处为之前训练好的模型
torch.save(trained_model.state_dict(), 'model_parameter.pkl')
# torch.save(trained_model.state_dict(), 'model_parameter.pt')

2. 模型参数加载

这种方法必须先初始化模型对应的结构,然后再加载参数。需要是被储存模型的模型结构,否则会出现错误。

  1. 实例化之前训练时的模型结构。此处随意定义一个模型结构

如下

import torch.nn as nn

# 此网络只是一个实例
class Model(nn.Module):#继承nn.Module
    def __init__(self,in_features,out_features):
        super().__init__()
        self.linear1 = nn.Linear(in_features, 5, bias = True)
        self.linear2 = nn.Linear(5, out_features, bias = True)
        self.relu=nn.ReLU()
        self.sig=nn.Sigmoid()
    
    def forward(self, x):
        s=self.linear1(x)
        s=self.relu(s)
        s=self.linear2(s)
        s=self.sig(s)
        return s
        
# 实例化一个模型结构
parameter_model = Model()
  1. 加载模型参数。
# 对实例的模型,加载模型参数
parameter_model.load_state_dict(torch.load('model_parameter.pkl'))
# parameter_model.load_state_dict(torch.load('model_parameter.pt'))

方法二:模型本身

1. 模型存储

import torch
# trained_model 为训练好的模型
model_path = "./model/output/model_self.pkl"  
# 定义一个模型储存的位置和文件名称为 model_path
torch.save(trained_model, model_path)

2. 读入模型

reload_model = torch.load(model_path)

3. 注意事项

由于我是用notebook来运行python代码,在使用方法二时发现了一个问题,就是当在训练模型过程中当前文件夹(简称文件夹A)下使用到了别的模块(引用了py文件C:例如定义网络结构的文件),而加载时在另一个文件夹下(简称文件夹B)无模块C,则加载会报错。需要将文件C复制到文件夹B下即可使用。


总结

本文总结了本地存储并加载PyTorch模型的两种常用方法,主要是对前段时间工作的知识的总结,并供自己后续复制使用。希望能对大家提供帮助,谢谢观看。