前几个月一直有不少小伙伴问我要「LightSeq的BERT推理加速代码」,当时内部已经使用了,但是一直没空整理开源。
现在代码终于整理好了,写了一个简单的样例,大家有需要的可以使用起来了。
实现原理
这里我直接使用预训练好的BERT模型,用户只需要输入一个带有[MASK]
标记的句子,就可以自动预测出完整的句子。
例如我输入“巴黎是[MASK]国的首都”,那么模型就会输出“巴黎是法国的首都。”。
LightSeq已经「完美支持了BERT模型的快速推理」,代码近期已经开源:
GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generation
BERT推理使用样例可以参考examples/inference/python
目录下的ls_bert.py
文件。我们用LightSeq来加速BERT推理试试。
首先需要安装LightSeq和Hugging Face:
pip install lightseq transformers
然后需要将Hugging Face的BERT模型导出为LightSeq支持的HDF5模型格式,运行examples/inference/python
目录下的hf_bert_export.py
文件即可,运行前将代码的第167-168两行修改为下面这样,指定使用中文版本的BERT预训练模型。
output_lightseq_model_name = "lightseq-bert-base-chinese"
input_huggingface_bert_model = "bert-base-chinese"
然后就会在运行目录下生成一个lightseq-bert-base-chinese.hdf5
模型文件,导出就成功啦。
最后使用LightSeq进行推理即可:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import lightseq.inference as lsi
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
hf_model.to("cuda:0")
ls_model = lsi.Bert("lightseq-bert-base-chinese.hdf5", 128)
while True:
raw_text = input("请输入中文句子,要预测的字符用#代替:\n> ")
input_text = raw_text.replace("#", "[MASK]")
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs["input_ids"]
mask = inputs["attention_mask"]
outputs = ls_model.infer(input_ids, mask)
logits = hf_model.cls(torch.Tensor(outputs).to(dtype=torch.float, device="cuda:0"))
output_ids = logits.argmax(axis=2)
res_text = tokenizer.batch_decode(output_ids)
res_text = res_text[0][1:-1].replace(" ", "")
output_text = list(raw_text)
for i in range(len(raw_text)):
if raw_text[i] == "#":
output_text[i] = res_text[i]
print("> " + "".join(output_text))
效果演示
给大家看看效果,运行我写好的代码,我们来看看会输出什么结果:
请输入中文句子,要预测的字符用#代替:
> 巴黎是#国的首都。
> 巴黎是法国的首都。
代码地址
GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generation
就在上周,首位外部贡献者出现了,修复了LightSeq的词嵌入表示的bug。
在这里我们非常欢迎感兴趣的同学来贡献自己的代码,包括但不局限于:修复bug、提供训练和推理样例、支持更多模型结构。