1. No fuss, straight to the code
Note: the environment set up as shown in the image below works fine; you also need modelscope>=1.9.1, and for gradio just install the latest version.
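Assuming a working PyTorch/CUDA environment as in the image, the two extra packages can be installed with pip (the version pin is the one mentioned above):

```shell
# ModelScope 1.9.1 or newer is required for the Baichuan2 model wrapper
pip install "modelscope>=1.9.1"
# Latest Gradio for the Blocks/Chatbot UI used below
pip install -U gradio
```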
import torch
import gradio as gr
from modelscope import snapshot_download, Model

# Download the model weights from ModelScope and load them in fp16,
# letting accelerate spread the layers across the available devices.
model_dir = snapshot_download(
    "baichuan-inc/Baichuan2-13B-Chat", revision='v1.0.1')
model = Model.from_pretrained(
    model_dir, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)


def clear_session():
    """Reset the chatbot to an empty history."""
    return []


def predict(input, history):
    if history is None:
        history = []
    # Convert Gradio's (user, assistant) pairs into the role/content
    # message list the model expects.
    model_input = []
    for chat in history:
        model_input.append({"role": "user", "content": chat[0]})
        model_input.append({"role": "assistant", "content": chat[1]})
    model_input.append({"role": "user", "content": input})
    print(model_input)
    response = model(model_input)["response"]
    history.append((input, response))
    # Keep only the 20 most recent turns to bound the prompt length.
    history = history[-20:]
    # Clear the textbox and refresh the chatbot display.
    return '', history


block = gr.Blocks()
with block as demo:
    gr.Markdown("""<h1><center>Baichuan2-13B-Chat</center></h1>
    <center>Baichuan2-13B-Chat is the aligned version in the Baichuan2-13B series; see Baichuan2-13B-Base for the pretrained model</center>
    """)
    chatbot = gr.Chatbot(label='Baichuan2-13B-Chat')
    message = gr.Textbox()
    # Pressing Enter in the textbox also sends the message.
    message.submit(predict,
                   inputs=[message, chatbot],
                   outputs=[message, chatbot])
    with gr.Row():
        clear_history = gr.Button("🧹 Clear history")
        send = gr.Button("🚀 Send")
    send.click(predict,
               inputs=[message, chatbot],
               outputs=[message, chatbot])
    clear_history.click(fn=clear_session,
                        inputs=[],
                        outputs=[chatbot],
                        queue=False)

demo.queue().launch(height=800, share=False)
2. A note on hardware
My own test environment has 32 GB of RAM and 24 GB of VRAM. The model's actual VRAM usage scales with the combined length of the input and output text, so if you never clear the history, the prompt keeps growing until VRAM usage blows up. My advice is to clear the history after every few questions.
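The 20-turn cap in `predict()` above bounds the number of turns but not their total length. A minimal sketch of a length-based alternative, which drops the oldest turns once the history exceeds a rough character budget (`max_chars` is an illustrative threshold, not a tuned value):

```python
def trim_history(history, max_chars=2000):
    """Drop the oldest (question, answer) pairs until the total length fits."""
    total = sum(len(q) + len(a) for q, a in history)
    while history and total > max_chars:
        q, a = history.pop(0)      # discard the oldest turn first
        total -= len(q) + len(a)
    return history

# One oversized old turn plus one short recent turn: only the recent one survives.
history = [("q" * 1500, "a" * 1500), ("short q", "short a")]
trimmed = trim_history(history, max_chars=2000)
```

Calling `trim_history(history)` in place of `history[-20:]` inside `predict()` would keep the prompt, and hence VRAM usage, more predictable regardless of how long individual turns are.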
The other issue is speed, which comes down to your GPU; raw compute is the hard limit and there is no way around that.
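To put a number on the speed, you can wrap the model call in a simple wall-clock timer. `timed_call` here is a hypothetical helper, not part of ModelScope or Gradio:

```python
import time

def timed_call(fn, *args, **kwargs):
    """Run fn and return its result together with the elapsed wall time in seconds."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed

# Stand-in function for illustration; in the demo you would wrap
# the model(model_input) call inside predict() instead.
result, seconds = timed_call(lambda x: x * 2, 21)
```

Dividing the number of generated characters (or tokens) by `seconds` gives a rough throughput figure for comparing GPUs.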
Results: