From 4b920f24d35c73814b83d56373dd5c913bb57e49 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 4 Apr 2024 02:07:20 +0800
Subject: [PATCH] back to gradio 4.21 and fix chat

---
 requirements.txt                         |  2 +-
 src/llmtuner/extras/misc.py              |  2 +-
 src/llmtuner/webui/chatter.py            | 24 +++++++++++++++---------
 src/llmtuner/webui/components/chatbot.py | 14 ++++++++------
 src/llmtuner/webui/components/infer.py   |  4 ++--
 5 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 3928d28d..1fa5a142 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ datasets>=2.14.3
 accelerate>=0.27.2
 peft>=0.10.0
 trl>=0.8.1
-gradio>=4.0.0
+gradio>=4.0.0,<=4.21.0
 scipy
 einops
 sentencepiece
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 2093d7ea..49b99eee 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -66,7 +66,7 @@ def check_dependencies() -> None:
     require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
     require_version("peft>=0.10.0", "To fix: pip install peft>=0.10.0")
     require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
-    require_version("gradio>=4.0.0", "To fix: pip install gradio>=4.0.0")
+    require_version("gradio>=4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
 
 
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index 2621bd5e..8c744153 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -92,23 +92,29 @@ class WebChatModel(ChatModel):
         torch_gc()
         yield ALERTS["info_unloaded"][lang]
 
-    def predict(
+    def append(
         self,
-        chatbot: List[Tuple[str, str]],
+        chatbot: List[List[Optional[str]]],
+        messages: Sequence[Dict[str, str]],
         role: str,
         query: str,
+    ) -> Tuple[List[List[Optional[str]]], List[Dict[str, str]], str]:
+        return chatbot + [[query, None]], messages + [{"role": role, "content": query}], ""
+
+    def stream(
+        self,
+        chatbot: List[List[Optional[str]]],
         messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[Tuple[str, str]], List[Dict[str, str]]], None, None]:
-        chatbot.append([query, ""])
-        query_messages = messages + [{"role": role, "content": query}]
+    ) -> Generator[Tuple[List[List[Optional[str]]], List[Dict[str, str]]], None, None]:
+        chatbot[-1][1] = ""
         response = ""
         for new_text in self.stream_chat(
-            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             if tools:
@@ -120,11 +126,11 @@ class WebChatModel(ChatModel):
                 name, arguments = result
                 arguments = json.loads(arguments)
                 tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
-                output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
+                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
                 bot_text = "```json\n" + tool_call + "\n```"
             else:
-                output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
+                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                 bot_text = result
 
-            chatbot[-1] = [query, bot_text]
+            chatbot[-1][1] = bot_text
             yield chatbot, output_messages
diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py
index d7d5bd66..8efd333c 100644
--- a/src/llmtuner/webui/components/chatbot.py
+++ b/src/llmtuner/webui/components/chatbot.py
@@ -35,13 +35,15 @@ def create_chat_box(
         tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
 
         submit_btn.click(
-            engine.chatter.predict,
-            [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
+            engine.chatter.append,
+            [chatbot, messages, role, query],
+            [chatbot, messages, query],
+        ).then(
+            engine.chatter.stream,
+            [chatbot, messages, system, tools, max_new_tokens, top_p, temperature],
             [chatbot, messages],
-            show_progress=True,
-        ).then(lambda: "", outputs=[query])
-
-        clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
+        )
+        clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
 
     return (
         chat_box,
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 097ded25..1e56d432 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -25,7 +25,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems.update({infer_backend})
     elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
 
-    chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
+    chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
 
     load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
@@ -33,7 +33,7 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     )
 
     unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
-        lambda: ([], []), outputs=[chatbot, history]
+        lambda: ([], []), outputs=[chatbot, messages]
     ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
 
     return elem_dict
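
Note: the sketch below is not part of the patch. It is a minimal, standalone illustration of the Gradio event-chaining pattern the commit switches to: submit_btn.click(...).then(...) runs a fast first step that appends the user turn and clears the textbox, then a generator step that streams the assistant reply. The callback names and the fake token list are illustrative stand-ins for WebChatModel.append / WebChatModel.stream in src/llmtuner/webui/chatter.py, and it assumes a gradio version within the pinned range (>=4.0.0,<=4.21.0).

import time

import gradio as gr


def append(chatbot, messages, query):
    # Step 1: show the user turn immediately; the third output ("") clears the textbox.
    return chatbot + [[query, None]], messages + [{"role": "user", "content": query}], ""


def stream(chatbot, messages):
    # Step 2: generator handler; each yield re-renders the chatbot with a longer reply.
    chatbot[-1][1] = ""
    for token in ["Hello", ",", " world", "!"]:  # stand-in for model tokens
        time.sleep(0.05)
        chatbot[-1][1] += token
        yield chatbot, messages + [{"role": "assistant", "content": chatbot[-1][1]}]


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    messages = gr.State([])
    query = gr.Textbox()
    submit_btn = gr.Button("Submit")
    submit_btn.click(append, [chatbot, messages, query], [chatbot, messages, query]).then(
        stream, [chatbot, messages], [chatbot, messages]
    )

if __name__ == "__main__":
    demo.launch()

Splitting the old predict callback this way is why the diff can drop the query input from the streaming step: by the time stream runs, append has already folded the query into both the chatbot pairs and the messages state.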