BERT推理加速代码

共 2260字,需浏览 5分钟

 ·

2021-08-13 00:25

LightSeq的BERT推理加速代码,大家有需要的可以使用起来了。

实现原理

这里我直接使用预训练好的BERT模型,用户只需要输入一个带有[MASK]标记的句子,就可以自动预测出完整的句子。

例如我输入“巴黎是[MASK]国的首都”,那么模型就会输出“巴黎是法国的首都。”。

LightSeq已经「完美支持了BERT模型的快速推理」,代码近期已经开源:


GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generationgithub.com/bytedance/lightseqgithub.com/bytedance/lightseq


BERT推理使用样例可以参考examples/inference/python目录下的ls_bert.py文件。我们用LightSeq来加速BERT推理试试。

首先需要安装LightSeq和Hugging Face:

pip install lightseq transformers

然后需要将Hugging Face的BERT模型导出为LightSeq支持的HDF5模型格式,运行examples/inference/python目录下的hf_bert_export.py文件即可,运行前将代码的第167-168两行修改为下面这样,指定使用中文版本的BERT预训练模型。

output_lightseq_model_name = "lightseq-bert-base-chinese"
input_huggingface_bert_model = "bert-base-chinese"

然后就会在运行目录下生成一个lightseq-bert-base-chinese.hdf5模型文件,导出就成功啦。

最后使用LightSeq进行推理即可:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import lightseq.inference as lsi

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
hf_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese")
hf_model.to("cuda:0")
ls_model = lsi.Bert("lightseq-bert-base-chinese.hdf5", 128)

while True:
raw_text = input("请输入中文句子,要预测的字符用#代替:\n> ")
input_text = raw_text.replace("#", "[MASK]")
inputs = tokenizer(input_text, return_tensors="pt")
input_ids = inputs["input_ids"]
mask = inputs["attention_mask"]

outputs = ls_model.infer(input_ids, mask)
logits = hf_model.cls(torch.Tensor(outputs).to(dtype=torch.float, device="cuda:0"))
output_ids = logits.argmax(axis=2)
res_text = tokenizer.batch_decode(output_ids)

res_text = res_text[0][1:-1].replace(" ", "")
output_text = list(raw_text)
for i in range(len(raw_text)):
if raw_text[i] == "#":
output_text[i] = res_text[i]
print("> " + "".join(output_text))

效果演示

给大家看看效果,运行我写好的代码,我们来看看会输出什么结果:

请输入中文句子,要预测的字符用#代替:
> 巴黎是#国的首都。
> 巴黎是法国的首都。

代码地址


GitHub - bytedance/lightseq: LightSeq: A High Performance Library for Sequence Processing and Generationgithub.com/bytedance/lightseqgithub.com/bytedance/lightseq


就在上周,首位外部贡献者出现了,修复了LightSeq的词嵌入表示的bug。

在这里我们非常欢迎感兴趣的同学来贡献自己的代码,包括但不局限于:修复bug、提供训练和推理样例、支持更多模型结构。


浏览 60
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报