网络结构和参数可以分开的保存和加载,因此,pytorch保存模型有两种方法:
- 保存 整个模型 (结构+参数)
- 只保存模型参数(官方推荐)
# 保存整个网络torch.save(model, checkpoint_path) # 保存网络中的参数, 速度快,占空间少torch.save(model.state_dict(),checkpoint_path)#--------------------------------------------------#针对上面一般的保存方法,加载的方法分别是:model_dict=torch.load(checkpoint_path)model_dict=model.load_state_dict(torch.load(checkpoint_path))
- 注意到,两者都是用torch.save(obj, dir)实现,这个函数的作用是将对象保存到磁盘中,它的内部是使用Python的pickle实现
- PyTorch约定使用.pt或.pth后缀命名保存文件
- 两种方法的区别其实就是obj参数的不同:前者的obj是整个model对象,后者的obj是从model对象里获取存储了model参数的词典,推荐用第二种,虽然麻烦了一丁点,但是比较灵活,有利于实现预训练、参数迁移等操作
一般加载模型是在训练完成后用模型做测试,这时候加载模型记得要加上model.eval(),把模型切换到evaluation模式,这时候会调整dropout和bactch的模式。
- 网络结构及其参数的保存与加载:load整个模型,完成了模型的定义和参数的加载这两个过程
torch.save(model,'model.pth') # 保存model = torch.load("model.pth") # 加载
- 只保存/加载模型参数:需要先创建一个网络模型,然后再load_state_dict()
torch.save(model.state_dict(),"model.pth") # 保存参数model = model() # 代码中创建网络结构params = torch.load("model.pth") # 加载参数model.load_state_dict(params) # 应用到网络结构中
重点介绍一下这种方法,一般训完一个模型之后不会只保存一个模型的参数,为了方便后续操作,比如恢复训练、参数迁移等,会保存当前状态的一个快照,格式以字典的格式存储,具体信息可以根据自己的需要,下面列出几个方面:
- 模型参数(不带模型的结构)
- 优化器参数
- loss
- epoch
- args
把这些信息用字典包装起来,然后保存即可。这种方式保存的只是参数,所以,在加载时需要先创建好模型,然后再把参数加载进去,如下:
# 获得保存信息save_data = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, 'epoch': epoch, 'args': args ...}# 保存torch.save(save_data , path)# 加载参数model_CKPT = torch.load(path)model = Mymodel()optimizer = Myoptimizer()model.load_state_dict(model_CKPT ['model_state_dict'])optimizer.load_state_dict(model_CKPT ['optimizer_state_dict'])...# 若对于加载参数,用函数表示,比如:def load_checkpoint(model, checkpoint_path, optimizer): if checkpoint_path != None: model_CKPT = torch.load(checkpoint_path) model.load_state_dict(model_CKPT['state_dict']) print('loading checkpoint!') optimizer.load_state_dict(model_CKPT['optimizer']) return model, optimizer
但是,对于已经保存好的模型参数,我们可能修改了一部分网络结构,比如加了一些,删除一些等等,那么需要过滤这些参数,加载方式如下:
def load_checkpoint(model, checkpoint_path, optimizer, loadOptimizer): if checkpoint_path != 'None': print("loading checkpoint...") model_dict = model.state_dict()# 修改后的模型随机初始化的参数 modelCheckpoint = torch.load(checkpoint_path) # 修改前的模型参数 pretrained_dict = modelCheckpoint['model_state_dict'] # 过滤操作 new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} # 获取修改后模型所需参数 model_dict.update(new_dict) # 打印出来,更新了多少参数 print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))# 修改后模型加载所需的,已经训练好的参数 model.load_state_dict(model_dict) print("loaded finished!") # 如果不需要更新优化器那么设置为false if loadOptimizer == True: optimizer.load_state_dict(modelCheckpoint['optimizer_state_dict']) print('loaded! optimizer') else: print('not loaded optimizer') else: print('No checkpoint_path is included') return model, optimizer