forked from p04798526/LLaMA-Factory-Mirror
add tool hint
This commit is contained in:
parent
487dee066f
commit
2abfe5fbc2
|
@ -1,6 +1,9 @@
|
|||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
from ..utils import check_json_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
@ -29,6 +32,8 @@ def create_chat_box(
|
|||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
|
||||
tools.input(check_json_schema, [tools])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
|
||||
|
|
|
@ -41,6 +41,13 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
|||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def check_json_schema(text: str) -> None:
|
||||
try:
|
||||
json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
gr.Warning("Invalid JSON schema")
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
|
|
Loading…
Reference in New Issue