LLaMA-Factory-310P3/mindie/examples/models/chatglm/v2_6b/web_demo.patch

110 lines
4.6 KiB
Diff
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

diff --git a/web_demo.py b/web_demo.py
index 1af24c9..8c0e765 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,14 +1,23 @@
-from transformers import AutoModel, AutoTokenizer
+import json
+import argparse
+import requests
+from transformers import AutoTokenizer
import gradio as gr
import mdtex2html
-from utils import load_model_on_gpus
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
-# 多显卡支持使用下面两行代替上面一行将num_gpus改为你实际的显卡数量
-# from utils import load_model_on_gpus
-# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
-model = model.eval()
+def parse_args():  # Parse CLI options for the MindIE-backed web demo.
+ parser = argparse.ArgumentParser(description="ChatGLM2-6B/ChatGLM3-6b web demo")
+ parser.add_argument("--model_path", type=str, required=True, help="The path to model weights")
+ parser.add_argument("--mindie_sever_ip", type=str, default="127.0.0.1", help="The IP address of mindie server")  # NOTE(review): "sever" looks like a typo for "server", but the flag name is the public CLI interface (and is read by name elsewhere in this patch) — renaming would break existing invocations
+ parser.add_argument("--mindie_sever_port", type=int, default=1025, help="The port of mindie server")  # NOTE(review): same "sever" typo as above; kept for interface compatibility
+ parser.add_argument("--max_new_tokens", type=int, default=512, help="Max new tokens to generate")
+ parser.add_argument("--concurrency", type=int, default=10, help="Concurrency count of web demo")
+
+ return parser.parse_args()
+
+
+args = parse_args()  # module-level: flags are parsed once at import/startup
+tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)  # tokenizer loaded locally; generation itself is delegated to the MindIE server
"""Override Chatbot.postprocess"""
@@ -71,6 +80,49 @@ def predict(input, chatbot, max_length, top_p, temperature, history, past_key_va
yield chatbot, history, past_key_values
+def build_inputs(tokenizer, query: str):  # Wrap a raw user query in the model's chat prompt template.
+ # Conversation history is managed internally by the serving backend, so none is passed here.
+ prompt = tokenizer.build_prompt(query, history=None)
+ return prompt
+
+
+def request(input, chatbot, max_length, top_p, temperature, history, past_key_values):  # Stream a completion from the MindIE server and incrementally update the gradio chatbot.
+ chatbot.append((parse_text(input), ""))
+
+ # Wrap the raw query in the chat prompt template so the server sees a proper chat turn
+ prompt = build_inputs(tokenizer, input)
+
+ response = requests.post(
+ f"http://{args.mindie_sever_ip}:{args.mindie_sever_port}/generate_stream",
+ json={
+ "inputs": prompt,
+ "parameters": {
+ "max_new_tokens": max_length,
+ "do_sample": True,
+ "repetition_penalty": 1.05,
+ "seed": None,
+ "temperature": temperature,
+ # "top_k": 1,
+ "top_p": top_p,
+ "batch_size": 1
+ },
+ },
+ verify=False, stream=True  # NOTE(review): verify=False disables TLS verification; acceptable only for a trusted local endpoint
+ )
+
+ generate_text = ""
+ for line in response.iter_lines():
+ if not line:
+ continue
+ # Strip the leading 'data: ' prefix of the server-sent-event line (6 characters)
+ res = line.decode('utf-8')[6:]
+ # Extract the incremental token text from the streamed JSON payload — assumes every event carries token.text; TODO confirm error events
+ res_text = json.loads(res).get('token').get('text')
+ generate_text += res_text
+ chatbot[-1] = (parse_text(input), parse_text(generate_text))
+ yield chatbot, history, past_key_values
+
+
def reset_user_input():
return gr.update(value='')
@@ -92,17 +144,17 @@ with gr.Blocks() as demo:
submitBtn = gr.Button("Submit", variant="primary")
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")
- max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
- top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
- temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+ max_length = gr.Slider(1, args.max_new_tokens, value=args.max_new_tokens, step=1.0, label="Maximum New Tokens", interactive=True)  # capped at the CLI-configured generation limit
+ top_p = gr.Slider(0.01, 0.99, value=0.01, step=0.01, label="Top P", interactive=True)  # 0.01 floor — presumably the server rejects 0; TODO confirm
+ temperature = gr.Slider(0.01, 1, value=0.01, step=0.01, label="Temperature", interactive=True)  # same 0.01 floor as top_p — TODO confirm server constraint
history = gr.State([])
past_key_values = gr.State(None)
- submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
+ submitBtn.click(request, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],  # route submissions through the HTTP streaming client instead of local inference
[chatbot, history, past_key_values], show_progress=True)
submitBtn.click(reset_user_input, [], [user_input])
emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
-demo.queue().launch(share=False, inbrowser=True)
+demo.queue(concurrency_count=args.concurrency).launch(server_name='0.0.0.0', share=False, inbrowser=True)  # NOTE(review): 0.0.0.0 exposes the demo on all interfaces; concurrency_count exists in gradio 3.x only — confirm pinned version