fix webui
This commit is contained in:
parent
9cde5e8af6
commit
34d8b2e56c
|
@ -26,9 +26,8 @@ class WebChatModel(ChatModel):
|
||||||
return self.model is not None
|
return self.model is not None
|
||||||
|
|
||||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||||
get = lambda name: data[self.manager.get_elem(name)]
|
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||||
lang = get("top.lang")
|
lang = get("top.lang")
|
||||||
|
|
||||||
if self.loaded:
|
if self.loaded:
|
||||||
yield ALERTS["err_exists"][lang]
|
yield ALERTS["err_exists"][lang]
|
||||||
return
|
return
|
||||||
|
@ -65,9 +64,7 @@ class WebChatModel(ChatModel):
|
||||||
yield ALERTS["info_loaded"][lang]
|
yield ALERTS["info_loaded"][lang]
|
||||||
|
|
||||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||||
get = lambda name: data[self.manager.get_elem(name)]
|
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||||
lang = get("top.lang")
|
|
||||||
|
|
||||||
yield ALERTS["info_unloading"][lang]
|
yield ALERTS["info_unloading"][lang]
|
||||||
self.model = None
|
self.model = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
|
@ -21,12 +21,12 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
export_btn.click(
|
export_btn.click(
|
||||||
save_model,
|
save_model,
|
||||||
[
|
[
|
||||||
engine.manager.get_elem("top.lang"),
|
engine.manager.get_elem_by_name("top.lang"),
|
||||||
engine.manager.get_elem("top.model_name"),
|
engine.manager.get_elem_by_name("top.model_name"),
|
||||||
engine.manager.get_elem("top.model_path"),
|
engine.manager.get_elem_by_name("top.model_path"),
|
||||||
engine.manager.get_elem("top.checkpoints"),
|
engine.manager.get_elem_by_name("top.checkpoints"),
|
||||||
engine.manager.get_elem("top.finetuning_type"),
|
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||||
engine.manager.get_elem("top.template"),
|
engine.manager.get_elem_by_name("top.template"),
|
||||||
max_shard_size,
|
max_shard_size,
|
||||||
export_dir
|
export_dir
|
||||||
],
|
],
|
||||||
|
|
|
@ -113,7 +113,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
|
|
||||||
refresh_btn.click(
|
refresh_btn.click(
|
||||||
list_checkpoint,
|
list_checkpoint,
|
||||||
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
|
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||||
[reward_model],
|
[reward_model],
|
||||||
queue=False
|
queue=False
|
||||||
)
|
)
|
||||||
|
@ -155,7 +155,11 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||||
|
|
||||||
output_box.change(
|
output_box.change(
|
||||||
gen_plot,
|
gen_plot,
|
||||||
[engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
|
[
|
||||||
|
engine.manager.get_elem_by_name("top.model_name"),
|
||||||
|
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||||
|
output_dir
|
||||||
|
],
|
||||||
loss_viewer,
|
loss_viewer,
|
||||||
queue=False
|
queue=False
|
||||||
)
|
)
|
||||||
|
|
|
@ -22,7 +22,7 @@ def create_ui() -> gr.Blocks:
|
||||||
|
|
||||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||||
engine.manager.all_elems["top"] = create_top()
|
engine.manager.all_elems["top"] = create_top()
|
||||||
lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
|
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
|
||||||
|
|
||||||
with gr.Tab("Train"):
|
with gr.Tab("Train"):
|
||||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||||
|
|
Loading…
Reference in New Issue