upgrade gradio to 4.21.0
This commit is contained in:
parent
a0333bb0ce
commit
831c5321ac
|
@ -4,7 +4,7 @@ datasets>=2.14.3
|
|||
accelerate>=0.27.2
|
||||
peft>=0.10.0
|
||||
trl>=0.8.1
|
||||
gradio>=3.38.0,<4.0.0
|
||||
gradio>4.0.0,<=4.21.0
|
||||
scipy
|
||||
einops
|
||||
sentencepiece
|
||||
|
|
|
@ -2,8 +2,7 @@ from llmtuner import Evaluator
|
|||
|
||||
|
||||
def main():
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
Evaluator().eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -36,7 +36,7 @@ class WebChatModel(ChatModel):
|
|||
return self.engine is not None
|
||||
|
||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang = get("top.lang")
|
||||
error = ""
|
||||
if self.loaded:
|
||||
|
@ -80,7 +80,7 @@ class WebChatModel(ChatModel):
|
|||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||
|
||||
if self.demo_mode:
|
||||
gr.Warning(ALERTS["err_demo"][lang])
|
||||
|
@ -97,13 +97,13 @@ class WebChatModel(ChatModel):
|
|||
chatbot: List[Tuple[str, str]],
|
||||
role: str,
|
||||
query: str,
|
||||
messages: Sequence[Tuple[str, str]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Dict[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
query_messages = messages + [{"role": role, "content": query}]
|
||||
response = ""
|
||||
|
@ -126,12 +126,5 @@ class WebChatModel(ChatModel):
|
|||
output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = result
|
||||
|
||||
chatbot[-1] = [query, self.postprocess(bot_text)]
|
||||
chatbot[-1] = [query, bot_text]
|
||||
yield chatbot, output_messages
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
|
|
|
@ -79,9 +79,9 @@ def get_template(model_name: str) -> str:
|
|||
return "default"
|
||||
|
||||
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type not in PEFT_METHODS:
|
||||
return gr.update(value=[], choices=[], interactive=False)
|
||||
return gr.Dropdown(value=[], choices=[], interactive=False)
|
||||
|
||||
adapters = []
|
||||
if model_name and finetuning_type == "lora":
|
||||
|
@ -92,7 +92,7 @@ def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
|||
os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES
|
||||
):
|
||||
adapters.append(adapter)
|
||||
return gr.update(value=[], choices=adapters, interactive=True)
|
||||
return gr.Dropdown(value=[], choices=adapters, interactive=True)
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
|
@ -104,12 +104,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
|||
return {}
|
||||
|
||||
|
||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
|
||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
|
||||
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
|
||||
return gr.update(value=[], choices=datasets)
|
||||
return gr.Dropdown(value=[], choices=datasets)
|
||||
|
||||
|
||||
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
|
||||
return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
|
||||
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
|
||||
return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
|
||||
|
|
|
@ -7,7 +7,6 @@ from ..utils import check_json_schema
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
from ..engine import Engine
|
||||
|
@ -15,9 +14,9 @@ if TYPE_CHECKING:
|
|||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Column(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot(show_copy_button=True)
|
||||
messages = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
|
@ -33,14 +32,14 @@ def create_chat_box(
|
|||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
clear_btn = gr.Button()
|
||||
|
||||
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
|
||||
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
|
||||
|
||||
submit_btn.click(
|
||||
engine.chatter.predict,
|
||||
[chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
|
||||
[chatbot, messages],
|
||||
show_progress=True,
|
||||
).then(lambda: gr.update(value=""), outputs=[query])
|
||||
).then(lambda: "", outputs=[query])
|
||||
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
@ -22,24 +22,24 @@ def next_page(page_index: int, total_num: int) -> int:
|
|||
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
|
||||
try:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
return gr.update(interactive=False)
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
return gr.Button(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
return gr.Button(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
|
@ -51,7 +51,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
|
|||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f] # noqa: C416
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
|
@ -67,7 +67,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
|||
close_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
preview_samples = gr.JSON()
|
||||
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False).then(
|
||||
lambda: 0, outputs=[page_index], queue=False
|
||||
|
@ -81,7 +81,7 @@ def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dic
|
|||
next_btn.click(next_page, [page_index, preview_count], [page_index], queue=False).then(
|
||||
get_preview, [dataset_dir, dataset, page_index], [preview_count, preview_samples, preview_box], queue=False
|
||||
)
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
close_btn.click(lambda: gr.Column(visible=False), outputs=[preview_box], queue=False)
|
||||
return dict(
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
|
|
|
@ -53,7 +53,7 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
output_elems = [output_box, process_bar]
|
||||
|
@ -68,9 +68,9 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
return elem_dict
|
||||
|
|
|
@ -74,7 +74,7 @@ def save_model(
|
|||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
|
||||
max_shard_size = gr.Slider(value=1, minimum=1, maximum=100, step=1)
|
||||
export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
|
||||
export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
|
||||
export_legacy_format = gr.Checkbox()
|
||||
|
@ -89,12 +89,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
export_btn.click(
|
||||
save_model,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.lang"),
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.model_path"),
|
||||
engine.manager.get_elem_by_name("top.adapter_path"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_name("top.template"),
|
||||
engine.manager.get_elem_by_id("top.lang"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.model_path"),
|
||||
engine.manager.get_elem_by_id("top.adapter_path"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.template"),
|
||||
max_shard_size,
|
||||
export_quantization_bit,
|
||||
export_quantization_dataset,
|
||||
|
|
|
@ -29,11 +29,11 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
|
||||
|
||||
return elem_dict
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
|
@ -25,7 +25,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
|||
adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default")
|
||||
|
@ -44,7 +44,7 @@ def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
|||
|
||||
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||
|
||||
return lang, dict(
|
||||
return dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
|
|
|
@ -68,7 +68,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Extra config", open=False) as extra_tab:
|
||||
with gr.Accordion(open=False) as extra_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
|
@ -113,7 +113,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
|
||||
with gr.Accordion(open=False) as freeze_tab:
|
||||
with gr.Row():
|
||||
num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
|
||||
name_module_trainable = gr.Textbox(value="all", scale=3)
|
||||
|
@ -125,7 +125,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Accordion(open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=1, scale=1)
|
||||
|
@ -155,7 +155,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Accordion(open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
dpo_ftx = gr.Slider(value=0, minimum=0, maximum=10, step=0.01, scale=1)
|
||||
|
@ -163,7 +163,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then(
|
||||
list_adapters,
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
@ -171,7 +171,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
with gr.Accordion(label="GaLore config", open=False) as galore_tab:
|
||||
with gr.Accordion(open=False) as galore_tab:
|
||||
with gr.Row():
|
||||
use_galore = gr.Checkbox(scale=1)
|
||||
galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
|
||||
|
@ -205,7 +205,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
resume_btn = gr.Checkbox(visible=False, interactive=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
|
@ -214,10 +214,10 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
input_elems.add(output_dir)
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
|
||||
|
||||
elem_dict.update(
|
||||
dict(
|
||||
|
@ -235,8 +235,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
|||
output_box.change(
|
||||
gen_plot,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_id("top.model_name"),
|
||||
engine.manager.get_elem_by_id("top.finetuning_type"),
|
||||
output_dir,
|
||||
],
|
||||
loss_viewer,
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from typing import Any, Dict, Generator
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
|
||||
from .chatter import WebChatModel
|
||||
|
@ -19,44 +18,45 @@ class Engine:
|
|||
self.runner = Runner(self.manager, demo_mode)
|
||||
self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
|
||||
r"""
|
||||
Gets the dict to update the components.
|
||||
"""
|
||||
output_dict: Dict["Component", "Component"] = {}
|
||||
for elem_id, elem_attr in input_dict.items():
|
||||
elem = self.manager.get_elem_by_id(elem_id)
|
||||
output_dict[elem] = elem.__class__(**elem_attr)
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
return output_dict
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Component], None, None]:
|
||||
user_config = load_config() if not self.demo_mode else {}
|
||||
lang = user_config.get("lang", None) or "en"
|
||||
|
||||
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
|
||||
|
||||
if not self.pure_chat:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
init_dict["train.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset().choices}
|
||||
init_dict["train.output_dir"] = {"value": "train_" + get_time()}
|
||||
init_dict["eval.output_dir"] = {"value": "eval_" + get_time()}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
yield self._form_dict(init_dict)
|
||||
yield self._update_component(init_dict)
|
||||
|
||||
if not self.pure_chat:
|
||||
if self.runner.alive and not self.demo_mode:
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._form_dict({"train.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||
if self.runner.alive and not self.demo_mode and not self.pure_chat:
|
||||
yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._update_component({"train.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict(
|
||||
{
|
||||
"train.output_dir": {"value": "train_" + get_time()},
|
||||
"eval.output_dir": {"value": "eval_" + get_time()},
|
||||
}
|
||||
)
|
||||
yield self._update_component({"eval.resume_btn": {"value": True}})
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
def change_lang(self, lang: str) -> Dict[Component, Component]:
|
||||
return {
|
||||
component: gr.update(**LOCALES[name][lang])
|
||||
for elems in self.manager.all_elems.values()
|
||||
for name, component in elems.items()
|
||||
if name in LOCALES
|
||||
elem: elem.__class__(**LOCALES[elem_name][lang])
|
||||
for elem_name, elem in self.manager.get_elem_iter()
|
||||
if elem_name in LOCALES
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ from .css import CSS
|
|||
from .engine import Engine
|
||||
|
||||
|
||||
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
|
||||
require_version("gradio>4.0.0,<=4.21.0", "To fix: pip install gradio==4.21.0")
|
||||
|
||||
|
||||
def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
||||
|
@ -29,23 +29,24 @@ def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
|||
)
|
||||
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
||||
|
||||
lang, engine.manager.all_elems["top"] = create_top()
|
||||
engine.manager.add_elem_dict("top", create_top())
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
|
||||
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
engine.manager.add_elem_dict("train", create_train_tab(engine))
|
||||
|
||||
with gr.Tab("Evaluate & Predict"):
|
||||
engine.manager.all_elems["eval"] = create_eval_tab(engine)
|
||||
engine.manager.add_elem_dict("eval", create_eval_tab(engine))
|
||||
|
||||
with gr.Tab("Chat"):
|
||||
engine.manager.all_elems["infer"] = create_infer_tab(engine)
|
||||
engine.manager.add_elem_dict("infer", create_infer_tab(engine))
|
||||
|
||||
if not demo_mode:
|
||||
with gr.Tab("Export"):
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
engine.manager.add_elem_dict("export", create_export_tab(engine))
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
@ -56,19 +57,17 @@ def create_web_demo() -> gr.Blocks:
|
|||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
||||
engine.manager.add_elem_dict("top", dict(lang=lang))
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||
engine.manager.add_elem_dict("infer", dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
|
||||
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Dict, List, Set
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -7,27 +7,49 @@ if TYPE_CHECKING:
|
|||
|
||||
class Manager:
|
||||
def __init__(self) -> None:
|
||||
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
||||
self._elem_dicts: Dict[str, Dict[str, "Component"]] = {}
|
||||
|
||||
def get_elem_by_name(self, name: str) -> "Component":
|
||||
def add_elem_dict(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
|
||||
r"""
|
||||
Adds a elem dict.
|
||||
"""
|
||||
self._elem_dicts[tab_name] = elem_dict
|
||||
|
||||
def get_elem_list(self) -> List["Component"]:
|
||||
r"""
|
||||
Returns the list of all elements.
|
||||
"""
|
||||
return [elem for elem_dict in self._elem_dicts.values() for elem in elem_dict.values()]
|
||||
|
||||
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
|
||||
r"""
|
||||
Returns an iterator over all elements with their names.
|
||||
"""
|
||||
for elem_dict in self._elem_dicts.values():
|
||||
for elem_name, elem in elem_dict.items():
|
||||
yield elem_name, elem
|
||||
|
||||
def get_elem_by_id(self, elem_id: str) -> "Component":
|
||||
r"""
|
||||
Gets element by id.
|
||||
|
||||
Example: top.lang, train.dataset
|
||||
"""
|
||||
tab_name, elem_name = name.split(".")
|
||||
return self.all_elems[tab_name][elem_name]
|
||||
tab_name, elem_name = elem_id.split(".")
|
||||
return self._elem_dicts[tab_name][elem_name]
|
||||
|
||||
def get_base_elems(self) -> Set["Component"]:
|
||||
r"""
|
||||
Gets the base elements that are commonly used.
|
||||
"""
|
||||
return {
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
self.all_elems["top"]["model_path"],
|
||||
self.all_elems["top"]["adapter_path"],
|
||||
self.all_elems["top"]["finetuning_type"],
|
||||
self.all_elems["top"]["quantization_bit"],
|
||||
self.all_elems["top"]["template"],
|
||||
self.all_elems["top"]["rope_scaling"],
|
||||
self.all_elems["top"]["booster"],
|
||||
self._elem_dicts["top"]["lang"],
|
||||
self._elem_dicts["top"]["model_name"],
|
||||
self._elem_dicts["top"]["model_path"],
|
||||
self._elem_dicts["top"]["finetuning_type"],
|
||||
self._elem_dicts["top"]["adapter_path"],
|
||||
self._elem_dicts["top"]["quantization_bit"],
|
||||
self._elem_dicts["top"]["template"],
|
||||
self._elem_dicts["top"]["rope_scaling"],
|
||||
self._elem_dicts["top"]["booster"],
|
||||
}
|
||||
|
||||
def list_elems(self) -> List["Component"]:
|
||||
return [elem for elems in self.all_elems.values() for elem in elems.values()]
|
||||
|
|
|
@ -48,8 +48,8 @@ class Runner:
|
|||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
|
||||
|
@ -95,8 +95,8 @@ class Runner:
|
|||
else:
|
||||
return finish_info
|
||||
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.adapter_path"):
|
||||
|
@ -196,8 +196,8 @@ class Runner:
|
|||
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
|
||||
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.adapter_path"):
|
||||
|
@ -232,6 +232,7 @@ class Runner:
|
|||
temperature=get("eval.temperature"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
|
||||
)
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if get("eval.predict"):
|
||||
args["do_predict"] = True
|
||||
|
@ -240,22 +241,20 @@ class Runner:
|
|||
|
||||
return args
|
||||
|
||||
def _preview(
|
||||
self, data: Dict[Component, Any], do_train: bool
|
||||
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
|
||||
error = self._initialize(data, do_train, from_preview=True)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
yield error, gr.Slider(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
yield gen_cmd(args), gr.Slider(visible=False)
|
||||
|
||||
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Tuple[str, "gr.Slider"], None, None]:
|
||||
error = self._initialize(data, do_train, from_preview=False)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
yield error, gr.Slider(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
|
@ -264,20 +263,20 @@ class Runner:
|
|||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
|
||||
yield from self._preview(data, do_train=True)
|
||||
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
|
||||
yield from self._preview(data, do_train=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
|
||||
yield from self._launch(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, gr.Slider], None, None]:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
|
||||
def monitor(self) -> Generator[Tuple[str, "gr.Slider"], None, None]:
|
||||
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
|
||||
self.running = True
|
||||
lang = get("top.lang")
|
||||
output_dir = get_save_dir(
|
||||
|
@ -286,13 +285,14 @@ class Runner:
|
|||
get("{}.output_dir".format("train" if self.do_train else "eval")),
|
||||
)
|
||||
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
while self.thread is not None and self.thread.is_alive():
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
yield ALERTS["info_aborting"][lang], gr.Slider(visible=False)
|
||||
else:
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
|
@ -304,4 +304,4 @@ class Runner:
|
|||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self._finalize(lang, finish_info), gr.update(visible=False)
|
||||
yield self._finalize(lang, finish_info), gr.Slider(visible=False)
|
||||
|
|
|
@ -19,26 +19,26 @@ if is_matplotlib_available():
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
|
||||
def update_process_bar(callback: "LogCallback") -> "gr.Slider":
|
||||
if not callback.max_steps:
|
||||
return gr.update(visible=False)
|
||||
return gr.Slider(visible=False)
|
||||
|
||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
callback.cur_steps, callback.max_steps, callback.elapsed_time, callback.remaining_time
|
||||
)
|
||||
return gr.update(label=label, value=percentage, visible=True)
|
||||
return gr.Slider(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="None", interactive=False)
|
||||
return gr.Dropdown(value="None", interactive=False)
|
||||
else:
|
||||
return gr.update(interactive=True)
|
||||
return gr.Dropdown(interactive=True)
|
||||
|
||||
|
||||
def check_json_schema(text: str, lang: str) -> None:
|
||||
|
@ -48,8 +48,8 @@ def check_json_schema(text: str, lang: str) -> None:
|
|||
assert isinstance(tools, list)
|
||||
for tool in tools:
|
||||
if "name" not in tool:
|
||||
raise ValueError("Name not found.")
|
||||
except ValueError:
|
||||
raise NotImplementedError("Name not found.")
|
||||
except NotImplementedError:
|
||||
gr.Warning(ALERTS["err_tool_name"][lang])
|
||||
except Exception:
|
||||
gr.Warning(ALERTS["err_json_schema"][lang])
|
||||
|
|
|
@ -2,9 +2,7 @@ from llmtuner import create_ui
|
|||
|
||||
|
||||
def main():
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
create_ui().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -2,9 +2,7 @@ from llmtuner import create_web_demo
|
|||
|
||||
|
||||
def main():
|
||||
demo = create_web_demo()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
create_web_demo().queue().launch(server_name="0.0.0.0", server_port=None, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue