网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法:

  1. 保存 整个模型 (结构+参数)
  2. 只保存模型参数(官方推荐)
# 保存整个网络torch.save(model, checkpoint_path) # 保存网络中的参数, 速度快,占空间少torch.save(model.state_dict(),checkpoint_path)#--------------------------------------------------#针对上面一般的保存方法,加载的方法分别是:model_dict=torch.load(checkpoint_path)model_dict=model.load_state_dict(torch.load(checkpoint_path))
  • 注意到,两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现
  • PyTorch约定使用.pt或.pth后缀命名保存文件
  • 两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model对象里获取存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作

一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。

  • 网络结构及其参数的保存与加载:load整个模型,完成了模型的定义和参数的加载这两个过程
torch.save(model,'model.pth') # 保存model = torch.load("model.pth") # 加载
  • 只保存/加载模型参数:需要先创建一个网络模型,然后再load_state_dict()
torch.save(model.state_dict(),"model.pth") # 保存参数model = model() # 代码中创建网络结构params = torch.load("model.pth") # 加载参数model.load_state_dict(params) # 应用到网络结构中

重点介绍一下这种方法,一般训完一个模型之后不会只保存一个模型的参数,为了方便后续操作,比如恢复训练、参数迁移等,会保存当前状态的一个快照,格式以字典的格式存储,具体信息可以根据自己的需要,下面列出几个方面:

  • 模型参数(不带模型的结构)
  • 优化器参数
  • loss
  • epoch
  • args

把这些信息用字典包装起来,然后保存即可。这种方式保存的只是参数,所以,在加载时需要先创建好模型,然后再把参数加载进去,如下:

# 获得保存信息save_data = {    'model_state_dict': model.state_dict(),    'optimizer_state_dict': optimizer.state_dict(),    'loss': loss,    'epoch': epoch,    'args': args     ...}# 保存torch.save(save_data , path)# 加载参数model_CKPT = torch.load(path)model = Mymodel()optimizer = Myoptimizer()model.load_state_dict(model_CKPT ['model_state_dict'])optimizer.load_state_dict(model_CKPT ['optimizer_state_dict'])...# 若对于加载参数,用函数表示,比如:def load_checkpoint(model, checkpoint_path, optimizer):    if checkpoint_path != None:        model_CKPT = torch.load(checkpoint_path)        model.load_state_dict(model_CKPT['state_dict'])        print('loading checkpoint!')        optimizer.load_state_dict(model_CKPT['optimizer'])    return model, optimizer

但是,对于已经保存好的模型参数,我们可能修改了一部分网络结构,比如加了一些,删除一些等等,那么需要过滤这些参数,加载方式如下:

def load_checkpoint(model, checkpoint_path, optimizer, loadOptimizer):    if checkpoint_path != 'None':        print("loading checkpoint...")        model_dict = model.state_dict()# 修改后的模型随机初始化的参数        modelCheckpoint = torch.load(checkpoint_path) # 修改前的模型参数        pretrained_dict = modelCheckpoint['model_state_dict']        # 过滤操作        new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}        # 获取修改后模型所需参数        model_dict.update(new_dict)        # 打印出来,更新了多少参数        print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))# 修改后模型加载所需的,已经训练好的参数        model.load_state_dict(model_dict)        print("loaded finished!")        # 如果不需要更新优化器那么设置为false        if loadOptimizer == True:            optimizer.load_state_dict(modelCheckpoint['optimizer_state_dict'])            print('loaded! optimizer')        else:            print('not loaded optimizer')    else:        print('No checkpoint_path is included')    return model, optimizer