From 34d8b2e56c36aea2d8d2ecf1f5e89f8cdfb776a3 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 2 Nov 2023 18:03:14 +0800 Subject: [PATCH] fix webui --- src/llmtuner/webui/chatter.py | 7 ++----- src/llmtuner/webui/components/export.py | 12 ++++++------ src/llmtuner/webui/components/train.py | 8 ++++++-- src/llmtuner/webui/interface.py | 2 +- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 712f2c75..40f04d18 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -26,9 +26,8 @@ class WebChatModel(ChatModel): return self.model is not None def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: - get = lambda name: data[self.manager.get_elem(name)] + get = lambda name: data[self.manager.get_elem_by_name(name)] lang = get("top.lang") - if self.loaded: yield ALERTS["err_exists"][lang] return @@ -65,9 +64,7 @@ class WebChatModel(ChatModel): yield ALERTS["info_loaded"][lang] def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: - get = lambda name: data[self.manager.get_elem(name)] - lang = get("top.lang") - + lang = data[self.manager.get_elem_by_name("top.lang")] yield ALERTS["info_unloading"][lang] self.model = None self.tokenizer = None diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index bfdc5dc8..75493d4a 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -21,12 +21,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: export_btn.click( save_model, [ - engine.manager.get_elem("top.lang"), - engine.manager.get_elem("top.model_name"), - engine.manager.get_elem("top.model_path"), - engine.manager.get_elem("top.checkpoints"), - engine.manager.get_elem("top.finetuning_type"), - engine.manager.get_elem("top.template"), + engine.manager.get_elem_by_name("top.lang"), + engine.manager.get_elem_by_name("top.model_name"), + engine.manager.get_elem_by_name("top.model_path"), + engine.manager.get_elem_by_name("top.checkpoints"), + engine.manager.get_elem_by_name("top.finetuning_type"), + engine.manager.get_elem_by_name("top.template"), max_shard_size, export_dir ], diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 20974d18..5c45268d 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -113,7 +113,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: refresh_btn.click( list_checkpoint, - [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")], + [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")], [reward_model], queue=False ) @@ -155,7 +155,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: output_box.change( gen_plot, - [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir], + [ + engine.manager.get_elem_by_name("top.model_name"), + engine.manager.get_elem_by_name("top.finetuning_type"), + output_dir + ], loss_viewer, queue=False ) diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index b9292891..f9dac510 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -22,7 +22,7 @@ def create_ui() -> gr.Blocks: with gr.Blocks(title="LLaMA Board", css=CSS) as demo: engine.manager.all_elems["top"] = create_top() - lang: "gr.Dropdown" = engine.manager.get_elem("top.lang") + lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang") with gr.Tab("Train"): engine.manager.all_elems["train"] = create_train_tab(engine)