From 2abfe5fbc2f79a87c741aadebf11dd2fce8670a2 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 18 Jan 2024 13:19:09 +0800 Subject: [PATCH] add tool hint --- src/llmtuner/webui/components/chatbot.py | 5 +++++ src/llmtuner/webui/utils.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index aa087536..ebc1b71f 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -1,6 +1,9 @@ import gradio as gr from typing import TYPE_CHECKING, Dict, Optional, Tuple +from ..utils import check_json_schema + + if TYPE_CHECKING: from gradio.blocks import Block from gradio.components import Component @@ -29,6 +32,8 @@ def create_chat_box( top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) + tools.input(check_json_schema, [tools]) + submit_btn.click( engine.chatter.predict, [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature], diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index c273b635..6bd80093 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -41,6 +41,13 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: return gr.update(interactive=True) +def check_json_schema(text: str) -> None: + try: + json.loads(text) + except json.JSONDecodeError: + gr.Warning("Invalid JSON schema") + + def gen_cmd(args: Dict[str, Any]) -> str: args.pop("disable_tqdm", None) args["plot_loss"] = args.get("do_train", None)