fix bugs in webui

This commit is contained in:
hiyouga 2023-10-15 03:41:58 +08:00
parent 7ed1fa6fe9
commit fde05cacfc
5 changed files with 38 additions and 38 deletions

View File

@ -37,12 +37,6 @@ def create_top(engine: "Engine") -> Dict[str, "Component"]:
shift_attn = gr.Checkbox(value=False) shift_attn = gr.Checkbox(value=False)
rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none") rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic"], value="none")
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
model_name.change( model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints] list_checkpoint, [model_name, finetuning_type], [checkpoints]
).then( ).then(

View File

@ -12,25 +12,27 @@ from llmtuner.webui.utils import get_time
class Engine: class Engine:
def __init__(self, init_chat: Optional[bool] = False) -> None: def __init__(self, pure_chat: Optional[bool] = False) -> None:
self.pure_chat = pure_chat
self.manager: "Manager" = Manager() self.manager: "Manager" = Manager()
self.runner: "Runner" = Runner(self.manager) self.runner: "Runner" = Runner(self.manager)
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not init_chat)) self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]: def resume(self, config: CONFIG_CLASS) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
lang = config.get("lang", None) or "en" lang = config.get("lang", None) or "en"
resume_dict = { resume_dict = {
"top.config": {"value": config},
"top.lang": {"value": lang}, "top.lang": {"value": lang},
"train.dataset": {"choices": list_dataset()["choices"]},
"eval.dataset": {"choices": list_dataset()["choices"]},
"infer.chat_box": {"visible": self.chatter.loaded} "infer.chat_box": {"visible": self.chatter.loaded}
} }
if config.get("last_model", None): if not self.pure_chat:
resume_dict["top.model_name"] = {"value": config["last_model"]} resume_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])} resume_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
if config.get("last_model", None):
resume_dict["top.model_name"] = {"value": config["last_model"]}
resume_dict["top.model_path"] = {"value": get_model_path(config, config["last_model"])}
yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()} yield {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
@ -42,5 +44,5 @@ class Engine:
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
return { return {
component: gr.update(**LOCALES[name][lang]) component: gr.update(**LOCALES[name][lang])
for elems in self.manager.all_elems.values() for name, component in elems.items() for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
} }

View File

@ -18,10 +18,14 @@ require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
def create_ui() -> gr.Blocks: def create_ui() -> gr.Blocks:
engine = Engine(init_chat=False) engine = Engine(pure_chat=False)
with gr.Blocks(title="Web Tuner", css=CSS) as demo: with gr.Blocks(title="Web Tuner", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top(engine) engine.manager.all_elems["top"] = create_top(engine)
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
config = engine.manager.get_elem("top.config")
model_name = engine.manager.get_elem("top.model_name")
model_path = engine.manager.get_elem("top.model_path")
with gr.Tab("Train"): with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine) engine.manager.all_elems["train"] = create_train_tab(engine)
@ -35,29 +39,37 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"): with gr.Tab("Export"):
engine.manager.all_elems["export"] = create_export_tab(engine) engine.manager.all_elems["export"] = create_export_tab(engine)
demo.load(engine.resume, [engine.manager.get_elem("top.config")], engine.manager.list_elems()) demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then(
save_config, inputs=[config, lang, model_name, model_path]
)
return demo return demo
def create_web_demo() -> gr.Blocks: def create_web_demo() -> gr.Blocks:
engine = Engine(init_chat=True) engine = Engine(pure_chat=True)
with gr.Blocks(title="Web Demo", css=CSS) as demo: with gr.Blocks(title="Web Demo", css=CSS) as demo:
lang = gr.Dropdown(choices=["en", "zh"])
config = gr.State(value=load_config()) config = gr.State(value=load_config())
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.all_elems["top"] = dict(config=config, lang=lang)
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
demo.load(engine.resume, [config], engine.manager.list_elems())
lang.change( lang.change(
engine.change_lang, [lang], engine.manager.list_elems(), queue=False engine.change_lang, [lang], engine.manager.list_elems(), queue=False
).then( ).then(
save_config, inputs=[config, lang] save_config, inputs=[config, lang]
) )
engine.manager.all_elems["top"] = dict(lang=lang)
_, _, _, engine.manager.all_elems["infer"] = create_chat_box(engine, visible=True)
demo.load(engine.resume, [config], engine.manager.list_elems())
return demo return demo

View File

@ -1,8 +1,4 @@
LOCALES = { LOCALES = {
"config": {
"en": {},
"zh": {}
},
"lang": { "lang": {
"en": { "en": {
"label": "Lang" "label": "Lang"
@ -447,10 +443,6 @@ LOCALES = {
"label": "保存预测结果" "label": "保存预测结果"
} }
}, },
"chat_box": {
"en": {},
"zh": {}
},
"load_btn": { "load_btn": {
"en": { "en": {
"value": "Load model" "value": "Load model"

View File

@ -199,10 +199,10 @@ class Runner:
yield gen_cmd(args), gr.update(visible=False) yield gen_cmd(args), gr.update(visible=False)
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
self.prepare(data, self._parse_train_args) yield from self.prepare(data, self._parse_train_args)
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
self.prepare(data, self._parse_eval_args) yield from self.prepare(data, self._parse_eval_args)
def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: def prepare(self, data: Dict[Component, Any], is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
parse_func = self._parse_train_args if is_training else self._parse_eval_args parse_func = self._parse_train_args if is_training else self._parse_eval_args
@ -213,9 +213,9 @@ class Runner:
else: else:
self.running = True self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = Thread(target=run_exp, kwargs=run_kwargs) self.thread = Thread(target=run_exp, kwargs=run_kwargs)
thread.start() self.thread.start()
yield self.monitor(lang, output_dir, is_training) yield from self.monitor(lang, output_dir, is_training)
def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: def monitor(self, lang: str, output_dir: str, is_training: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
while self.thread.is_alive(): while self.thread.is_alive():