背景

使用pytorch加载huggingface下载的albert-base-chinede模型出错

Exception has occurred: OSError
Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

模型地址:​​https://huggingface.co/models?search=albert_chinese​

方法一:

参考以下文章删除缓存目录,问题还是存在​

​​

​​​​

方法二:

使用另一台电脑加载相同模型,加载成功,查看两台电脑的torch、transformers版本,发现一个torch为1.1,另一个为torch1.7.x

参考pytorch官网,torch1.6之后修改了模型保存方式,高版本保存的模型,低版本无法加载

The 1.6 release of PyTorch switched torch.save to use a new zipfile-based file format. torch.load still retains the ability to load files in the old format. If for any reason you want torch.save to use the old format, pass the kwarg _use_new_zipfile_serialization=False.

解决方法:


  1. 升级torch为高版本
  2. 如果因为cuda兼容等问题无法升级,可以在高版本上加载模型,然后重新save并添加_use_new_zipfile_serialization=False

from transformers import *
import torch

pretrained = 'D:/07_data/albert_base_chinese'
tokenizer = BertTokenizer.from_pretrained(pretrained)
model = AlbertForMaskedLM.from_pretrained(pretrained)

# 它包装在PyTorch DistributedDataParallel或DataParallel中
model_to_save = model.module if hasattr(model, 'module') else model

torch.save(model_to_save.state_dict(), 'pytorch_model_unzip.bin', _use_new_zipfile_serialization=False)

其他保存方法请参考:

​​

​​​​

时间会记录下一切。