fix bugs in webui
This commit is contained in:
parent
7ed1fa6fe9
commit
fde05cacfc
|
@ -37,12 +37,6 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]:
|
|||
shift_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
lang.change(
|
||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||
).then(
|
||||
save_config, inputs=[config, lang, model_name, model_path]
|
||||
)
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
).then(
|
||||
|
|
|
@ -12,22 +12,24 @@ from llmtuner.webui.utils import get_time
|
|||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, init_chat: Optional[bool] = False) -> None:
|
||||
def __init__(self, pure_chat: Optional[bool] = False) -> None:
|
||||
self.pure_chat = pure_chat
|
||||
self.manager: "Manager" = Manager()
|
||||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat))
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
lang = config.get("lang", None) or "en"
|
||||
|
||||
resume_dict = {
|
||||
"top.config": {"value": config},
|
||||
"top.lang": {"value": lang},
|
||||
"train.dataset": {"choices": list_dataset()["choices"]},
|
||||
"eval.dataset": {"choices": list_dataset()["choices"]},
|
||||
"infer.chat_box": {"visible": self.chatter.loaded}
|
||||
}
|
||||
|
||||
if not self.pure_chat:
|
||||
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
|
||||
if config.get("last_model", None):
|
||||
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
||||
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
||||
|
@ -42,5 +44,5 @@ class Engine:
|
|||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
component: gr.update(**LOCALES[name][lang])
|
||||
for elems in self.manager.all_elems.values() for name, component in elems.items()
|
||||
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
||||
}
|
||||
|
|
|
@ -18,10 +18,14 @@ require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
|||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
engine = Engine(init_chat=False)
|
||||
engine = Engine(pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
engine.manager.all_elems["top"] = create_top(engine)
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
|
||||
config = engine.manager.get_elem("top.config")
|
||||
model_name = engine.manager.get_elem("top.model_name")
|
||||
model_path = engine.manager.get_elem("top.model_path")
|
||||
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
|
@ -35,29 +39,37 @@ def create_ui() -> gr.Blocks:
|
|||
with gr.Tab("Export"):
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
|
||||
demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())
|
||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||
|
||||
lang.change(
|
||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||
).then(
|
||||
save_config, inputs=[config, lang, model_name, model_path]
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def create_web_demo() -> gr.Blocks:
|
||||
engine = Engine(init_chat=True)
|
||||
engine = Engine(pure_chat=True)
|
||||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
config = gr.State(value=load_config())
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
|
||||
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||
|
||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||
|
||||
lang.change(
|
||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||
).then(
|
||||
save_config, inputs=[config, lang]
|
||||
)
|
||||
|
||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
||||
|
||||
_, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)
|
||||
|
||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
LOCALES = {
|
||||
"config": {
|
||||
"en": {},
|
||||
"zh": {}
|
||||
},
|
||||
"lang": {
|
||||
"en": {
|
||||
"label": "Lang"
|
||||
|
@ -447,10 +443,6 @@ LOCALES = {
|
|||
"label": "保存预测结果"
|
||||
}
|
||||
},
|
||||
"chat_box": {
|
||||
"en": {},
|
||||
"zh": {}
|
||||
},
|
||||
"load_btn": {
|
||||
"en": {
|
||||
"value": "Load model"
|
||||
|
|
|
@ -199,10 +199,10 @@ class Runner:
|
|||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
self.prepare(data, self._parse_train_args)
|
||||
yield from self.prepare(data, self._parse_train_args)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
self.prepare(data, self._parse_eval_args)
|
||||
yield from self.prepare(data, self._parse_eval_args)
|
||||
|
||||
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if is_training else self._parse_eval_args
|
||||
|
@ -213,9 +213,9 @@ class Runner:
|
|||
else:
|
||||
self.running = True
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
thread.start()
|
||||
yield self.monitor(lang, output_dir, is_training)
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor(lang, output_dir, is_training)
|
||||
|
||||
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
while self.thread.is_alive():
|
||||
|
|
Loading…
Reference in New Issue