torch.save()和torch.load():

torch.save()和torch.load()配合使用,
分别用来保存一个对象(任何对象,
不一定要是PyTorch中的对象)到文件,和从文件中加载一个对象.
加载的时候可以指明是否需要数据在CPU和GPU中相互移动.

Module.state_dict()和Module.load_state_dict():

Module.state_dict()返回一个字典,
该字典以键值对的方式保存了Module的整个状态.

Module.load_state_dict()可以从一个字典中加载参数到这个module和其后代,
如果strict是True,
那么所加载的字典和该module本身state_dict()方法返回的关键字必须严格确切的匹配上.
If strict is True, 
then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
返回值是一个命名元组:
NamedTuple with missing_keys and unexpected_keys fields,
分别保存缺失的关键字和未预料到的关键字.
如果自己的模型跟预训练模型只有部分层是相同的,
那么可以只加载这部分相同的参数,
只要设置strict参数为False来忽略那些没有匹配到的keys即可。
# 方式1:
# model_path = 'model_name.pth'
# model_params_path = 'params_name.pth'
# ----保存----
# torch.save(model, model_path)
# ----加载----
# model = torch.load(model_path)


# 方式2:
#----保存----
# torch.save(model.state_dict(), model_params_path) #保存的文件名后缀一般是.pt或.pth
#----加载----
# model=Model().cuda() #定义模型结构
# model.load_state_dict(torch.load(model_params_path))  #加载模型参数

说明:

# 保存/加载整个模型
torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
这种保存/加载模型的过程使用了最直观的语法,
所用代码量少。这使用Python的pickle保存所有模块。
这种方法的缺点是,保存模型的时候,
序列化的数据被绑定到了特定的类和确切的目录。
这是因为pickle不保存模型类本身,而是保存这个类的路径,
并且在加载的时候会使用。因此,
当在其他项目里使用或者重构的时候,加载模型的时候会出错。




# 保存/加载 state_dict(推荐)
torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

自己选择要保存的参数,设置checkpoint:

#----保存----
torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
  	'loss': loss,
    'best_prec1': best_prec1,}, 
    'checkpoint_name.tar' )

#----加载----
checkpoint = torch.load('checkpoint_name.tar')

#按关键字获取保存的参数
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
state_dict=checkpoint['state_dict']

model=Model()#定义模型结构
model.load_state_dict(state_dict)

保存多个模型到同一个文件:

#----保存----
torch.save({
  'modelA_state_dict': modelA.state_dict(),
  'modelB_state_dict': modelB.state_dict(),
  'optimizerA_state_dict': optimizerA.state_dict(),
  'optimizerB_state_dict': optimizerB.state_dict(),
  ...
  }, PATH)

#----加载----
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelAClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict']
modelB.load_state_dict(checkpoint['modelB_state_dict']
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']

modelA.eval()
modelB.eval()
# or
modelA.train()
modelB.train()
# 在这里,保存完模型后加载的时候有时会
# 遇到CUDA out of memory的问题,
# 我google到的解决方法是加上map_location=‘cpu’

checkpoint = torch.load(PATH,map_location='cpu')

加载预训练模型的部分:

resnet152 = models.resnet152(pretrained=True) #加载模型结构和参数
pretrained_dict = resnet152.state_dict()
"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
   也可以直接从官方model_zoo下载:
   pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

或者写详细一点:

model_dict = model.state_dict()
state_dict = {}
for k, v in pretrained_dict.items():
    if k in model_dict.keys():
        # state_dict.setdefault(k, v)
        state_dict[k] = v
    else:
        print("Missing key(s) in state_dict :{}".format(k))
model_dict.update(state_dict)
model.load_state_dict(model_dict)