fix webui

This commit is contained in:
hiyouga 2023-11-02 18:03:14 +08:00
parent 9cde5e8af6
commit 34d8b2e56c
4 changed files with 15 additions and 14 deletions

View File

@ -26,9 +26,8 @@ class WebChatModel(ChatModel):
return self.model is not None return self.model is not None
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)] get = lambda name: data[self.manager.get_elem_by_name(name)]
lang = get("top.lang") lang = get("top.lang")
if self.loaded: if self.loaded:
yield ALERTS["err_exists"][lang] yield ALERTS["err_exists"][lang]
return return
@ -65,9 +64,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang] yield ALERTS["info_loaded"][lang]
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
get = lambda name: data[self.manager.get_elem(name)] lang = data[self.manager.get_elem_by_name("top.lang")]
lang = get("top.lang")
yield ALERTS["info_unloading"][lang] yield ALERTS["info_unloading"][lang]
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None

View File

@ -21,12 +21,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
export_btn.click( export_btn.click(
save_model, save_model,
[ [
engine.manager.get_elem("top.lang"), engine.manager.get_elem_by_name("top.lang"),
engine.manager.get_elem("top.model_name"), engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem("top.model_path"), engine.manager.get_elem_by_name("top.model_path"),
engine.manager.get_elem("top.checkpoints"), engine.manager.get_elem_by_name("top.checkpoints"),
engine.manager.get_elem("top.finetuning_type"), engine.manager.get_elem_by_name("top.finetuning_type"),
engine.manager.get_elem("top.template"), engine.manager.get_elem_by_name("top.template"),
max_shard_size, max_shard_size,
export_dir export_dir
], ],

View File

@ -113,7 +113,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
refresh_btn.click( refresh_btn.click(
list_checkpoint, list_checkpoint,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")], [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
[reward_model], [reward_model],
queue=False queue=False
) )
@ -155,7 +155,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
output_box.change( output_box.change(
gen_plot, gen_plot,
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir], [
engine.manager.get_elem_by_name("top.model_name"),
engine.manager.get_elem_by_name("top.finetuning_type"),
output_dir
],
loss_viewer, loss_viewer,
queue=False queue=False
) )

View File

@ -22,7 +22,7 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="LLaMA Board", css=CSS) as demo: with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
engine.manager.all_elems["top"] = create_top() engine.manager.all_elems["top"] = create_top()
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang") lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
with gr.Tab("Train"): with gr.Tab("Train"):
engine.manager.all_elems["train"] = create_train_tab(engine) engine.manager.all_elems["train"] = create_train_tab(engine)