diff --git a/web_demo.py b/web_demo.py
index 1af24c9..8c0e765 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -1,14 +1,23 @@
-from transformers import AutoModel, AutoTokenizer
+import json
+import argparse
+import requests
+from transformers import AutoTokenizer
 import gradio as gr
 import mdtex2html
-from utils import load_model_on_gpus
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
-# For multi-GPU support, use the two lines below instead of the line above, and set num_gpus to the actual number of GPUs
-# from utils import load_model_on_gpus
-# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
-model = model.eval()
+def parse_args():
+    parser = argparse.ArgumentParser(description="ChatGLM2-6B/ChatGLM3-6B web demo")
+    parser.add_argument("--model_path", type=str, required=True, help="The path to the model weights")
+    parser.add_argument("--mindie_server_ip", type=str, default="127.0.0.1", help="The IP address of the MindIE server")
+    parser.add_argument("--mindie_server_port", type=int, default=1025, help="The port of the MindIE server")
+    parser.add_argument("--max_new_tokens", type=int, default=512, help="Max new tokens to generate")
+    parser.add_argument("--concurrency", type=int, default=10, help="Concurrency count of the web demo")
+
+    return parser.parse_args()
+
+
+args = parse_args()
+tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
 
"""Override Chatbot.postprocess"""
|
||
|
||
@@ -71,6 +80,49 @@ def predict(input, chatbot, max_length, top_p, temperature, history, past_key_va
|
||
yield chatbot, history, past_key_values
|
||
|
||
|
||
+def build_inputs(tokenizer, query: str):
+    # conversation history is handled internally by the serving backend
+    prompt = tokenizer.build_prompt(query, history=None)
+    return prompt
+
+
+def request(input, chatbot, max_length, top_p, temperature, history, past_key_values):
+    chatbot.append((parse_text(input), ""))
+
+    # wrap the query in the chat prompt format so the server sees a proper chat turn
+    prompt = build_inputs(tokenizer, input)
+
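+    # request token-by-token streaming from the MindIE server's /generate_stream endpoint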
+    response = requests.post(
+        f"http://{args.mindie_server_ip}:{args.mindie_server_port}/generate_stream",
+        json={
+            "inputs": prompt,
+            "parameters": {
+                "max_new_tokens": max_length,
+                "do_sample": True,
+                "repetition_penalty": 1.05,
+                "seed": None,
+                "temperature": temperature,
+                # "top_k": 1,
+                "top_p": top_p,
+                "batch_size": 1
+            },
+        },
+        verify=False, stream=True
+    )
+
+    generate_text = ""
+    for line in response.iter_lines():
+        if not line:
+            continue
+        # strip the leading 'data: ' prefix of each server-sent event
+        res = line.decode('utf-8')[6:]
+        # extract the streamed text content of this token
+        res_text = json.loads(res).get('token', {}).get('text', '')
+        generate_text += res_text
+        chatbot[-1] = (parse_text(input), parse_text(generate_text))
+        yield chatbot, history, past_key_values
+
+
 def reset_user_input():
     return gr.update(value='')
 
@@ -92,17 +145,18 @@ with gr.Blocks() as demo:
             submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
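+            # keep top_p strictly inside (0, 1) and temperature positive, so requests stay within ranges the serving endpoint accepts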
+            max_length = gr.Slider(1, args.max_new_tokens, value=args.max_new_tokens, step=1.0, label="Maximum New Tokens", interactive=True)
+            top_p = gr.Slider(0.01, 0.99, value=0.01, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0.01, 1, value=0.01, step=0.01, label="Temperature", interactive=True)
 
     history = gr.State([])
     past_key_values = gr.State(None)
 
-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
+    submitBtn.click(request, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                     [chatbot, history, past_key_values], show_progress=True)
     submitBtn.click(reset_user_input, [], [user_input])
 
     emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)
 
-demo.queue().launch(share=False, inbrowser=True)
+demo.queue(concurrency_count=args.concurrency).launch(server_name='0.0.0.0', share=False, inbrowser=True)
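
For reference: with this change the demo no longer loads model weights locally; it only needs the tokenizer files under --model_path plus a MindIE server already serving the model at the given address. A typical launch (the weight path and server address below are placeholders) might look like:

    python web_demo.py --model_path /path/to/chatglm2-6b --mindie_server_ip 127.0.0.1 --mindie_server_port 1025 --max_new_tokens 512 --concurrency 10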
|