fix bugs in webui
This commit is contained in:
parent
7ed1fa6fe9
commit
fde05cacfc
|
@ -37,12 +37,6 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
shift_attn = gr.Checkbox(value=False)
|
shift_attn = gr.Checkbox(value=False)
|
||||||
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
|
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
|
||||||
|
|
||||||
lang.change(
|
|
||||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
|
||||||
).then(
|
|
||||||
save_config, inputs=[config, lang, model_name, model_path]
|
|
||||||
)
|
|
||||||
|
|
||||||
model_name.change(
|
model_name.change(
|
||||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||||
).then(
|
).then(
|
||||||
|
|
|
@ -12,25 +12,27 @@ from llmtuner.webui.utils import get_time
|
||||||
|
|
||||||
class Engine:
|
class Engine:
|
||||||
|
|
||||||
def __init__(self, init_chat: Optional[bool] = False) -> None:
|
def __init__(self, pure_chat: Optional[bool] = False) -> None:
|
||||||
|
self.pure_chat = pure_chat
|
||||||
self.manager: "Manager" = Manager()
|
self.manager: "Manager" = Manager()
|
||||||
self.runner: "Runner" = Runner(self.manager)
|
self.runner: "Runner" = Runner(self.manager)
|
||||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat))
|
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||||
|
|
||||||
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||||
lang = config.get("lang", None) or "en"
|
lang = config.get("lang", None) or "en"
|
||||||
|
|
||||||
resume_dict = {
|
resume_dict = {
|
||||||
"top.config": {"value": config},
|
|
||||||
"top.lang": {"value": lang},
|
"top.lang": {"value": lang},
|
||||||
"train.dataset": {"choices": list_dataset()["choices"]},
|
|
||||||
"eval.dataset": {"choices": list_dataset()["choices"]},
|
|
||||||
"infer.chat_box": {"visible": self.chatter.loaded}
|
"infer.chat_box": {"visible": self.chatter.loaded}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.get("last_model", None):
|
if not self.pure_chat:
|
||||||
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||||
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||||
|
|
||||||
|
if config.get("last_model", None):
|
||||||
|
resume_dict["top.model_name"] = {"value": config["last_model"]}
|
||||||
|
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
|
||||||
|
|
||||||
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||||
|
|
||||||
|
@ -42,5 +44,5 @@ class Engine:
|
||||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
component: gr.update(**LOCALES[name][lang])
|
component: gr.update(**LOCALES[name][lang])
|
||||||
for elems in self.manager.all_elems.values() for name, component in elems.items()
|
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,10 +18,14 @@ require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||||
|
|
||||||
|
|
||||||
def create_ui() -> gr.Blocks:
|
def create_ui() -> gr.Blocks:
|
||||||
engine = Engine(init_chat=False)
|
engine = Engine(pure_chat=False)
|
||||||
|
|
||||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||||
engine.manager.all_elems["top"] = create_top(engine)
|
engine.manager.all_elems["top"] = create_top(engine)
|
||||||
|
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
|
||||||
|
config = engine.manager.get_elem("top.config")
|
||||||
|
model_name = engine.manager.get_elem("top.model_name")
|
||||||
|
model_path = engine.manager.get_elem("top.model_path")
|
||||||
|
|
||||||
with gr.Tab("Train"):
|
with gr.Tab("Train"):
|
||||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||||
|
@ -35,29 +39,37 @@ def create_ui() -> gr.Blocks:
|
||||||
with gr.Tab("Export"):
|
with gr.Tab("Export"):
|
||||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||||
|
|
||||||
demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems())
|
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||||
|
|
||||||
|
lang.change(
|
||||||
|
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||||
|
).then(
|
||||||
|
save_config, inputs=[config, lang, model_name, model_path]
|
||||||
|
)
|
||||||
|
|
||||||
return demo
|
return demo
|
||||||
|
|
||||||
|
|
||||||
def create_web_demo() -> gr.Blocks:
|
def create_web_demo() -> gr.Blocks:
|
||||||
engine = Engine(init_chat=True)
|
engine = Engine(pure_chat=True)
|
||||||
|
|
||||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||||
lang = gr.Dropdown(choices=["en", "zh"])
|
|
||||||
config = gr.State(value=load_config())
|
config = gr.State(value=load_config())
|
||||||
|
lang = gr.Dropdown(choices=["en", "zh"])
|
||||||
|
|
||||||
|
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
|
||||||
|
|
||||||
|
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||||
|
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||||
|
|
||||||
|
demo.load(engine.resume, [config], engine.manager.list_elems())
|
||||||
|
|
||||||
lang.change(
|
lang.change(
|
||||||
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
|
||||||
).then(
|
).then(
|
||||||
save_config, inputs=[config, lang]
|
save_config, inputs=[config, lang]
|
||||||
)
|
)
|
||||||
|
|
||||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
|
||||||
|
|
||||||
_, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)
|
|
||||||
|
|
||||||
demo.load(engine.resume, [config], engine.manager.list_elems())
|
|
||||||
|
|
||||||
return demo
|
return demo
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,4 @@
|
||||||
LOCALES = {
|
LOCALES = {
|
||||||
"config": {
|
|
||||||
"en": {},
|
|
||||||
"zh": {}
|
|
||||||
},
|
|
||||||
"lang": {
|
"lang": {
|
||||||
"en": {
|
"en": {
|
||||||
"label": "Lang"
|
"label": "Lang"
|
||||||
|
@ -447,10 +443,6 @@ LOCALES = {
|
||||||
"label": "保存预测结果"
|
"label": "保存预测结果"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"chat_box": {
|
|
||||||
"en": {},
|
|
||||||
"zh": {}
|
|
||||||
},
|
|
||||||
"load_btn": {
|
"load_btn": {
|
||||||
"en": {
|
"en": {
|
||||||
"value": "Load model"
|
"value": "Load model"
|
||||||
|
|
|
@ -199,10 +199,10 @@ class Runner:
|
||||||
yield gen_cmd(args), gr.update(visible=False)
|
yield gen_cmd(args), gr.update(visible=False)
|
||||||
|
|
||||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
self.prepare(data, self._parse_train_args)
|
yield from self.prepare(data, self._parse_train_args)
|
||||||
|
|
||||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
self.prepare(data, self._parse_eval_args)
|
yield from self.prepare(data, self._parse_eval_args)
|
||||||
|
|
||||||
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
parse_func = self._parse_train_args if is_training else self._parse_eval_args
|
parse_func = self._parse_train_args if is_training else self._parse_eval_args
|
||||||
|
@ -213,9 +213,9 @@ class Runner:
|
||||||
else:
|
else:
|
||||||
self.running = True
|
self.running = True
|
||||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||||
thread = Thread(target=run_exp, kwargs=run_kwargs)
|
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||||
thread.start()
|
self.thread.start()
|
||||||
yield self.monitor(lang, output_dir, is_training)
|
yield from self.monitor(lang, output_dir, is_training)
|
||||||
|
|
||||||
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||||
while self.thread.is_alive():
|
while self.thread.is_alive():
|
||||||
|
|
Loading…
Reference in New Issue