support warning in webui
This commit is contained in:
parent
f8703aac08
commit
9cde5e8af6
|
@ -73,14 +73,15 @@ def get_template(model_name: str) -> str:
|
|||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
if model_name:
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ class Engine:
|
|||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
user_config = load_config()
|
||||
|
@ -42,7 +42,7 @@ class Engine:
|
|||
|
||||
if not self.pure_chat:
|
||||
if self.runner.alive:
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._form_dict({"train.resume_btn": {"value": True}})
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import TYPE_CHECKING, Dict, List
|
||||
from typing import TYPE_CHECKING, Dict, List, Set
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
@ -9,14 +9,14 @@ class Manager:
|
|||
def __init__(self) -> None:
|
||||
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
||||
|
||||
def get_elem(self, name: str) -> "Component":
|
||||
def get_elem_by_name(self, name: str) -> "Component":
|
||||
r"""
|
||||
Example: top.lang, train.dataset
|
||||
"""
|
||||
tab_name, elem_name = name.split(".")
|
||||
return self.all_elems[tab_name][elem_name]
|
||||
|
||||
def get_base_elems(self):
|
||||
def get_base_elems(self) -> Set["Component"]:
|
||||
return {
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
|
|
|
@ -26,12 +26,15 @@ class Runner:
|
|||
|
||||
def __init__(self, manager: "Manager") -> None:
|
||||
self.manager = manager
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.data: Dict["Component", Any] = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
|
@ -45,7 +48,11 @@ class Runner:
|
|||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
|
||||
if self.running:
|
||||
return ALERTS["err_conflict"][lang]
|
||||
|
||||
|
@ -72,8 +79,8 @@ class Runner:
|
|||
else:
|
||||
return finish_info
|
||||
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
|
@ -83,8 +90,6 @@ class Runner:
|
|||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
model_name_or_path=get("top.model_path"),
|
||||
|
@ -119,7 +124,7 @@ class Runner:
|
|||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") if get("train.additional_target") else None,
|
||||
resume_lora_training=get("train.resume_lora_training"),
|
||||
output_dir=output_dir
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||
)
|
||||
args[get("train.compute_type")] = True
|
||||
args["disable_tqdm"] = True
|
||||
|
@ -142,10 +147,10 @@ class Runner:
|
|||
args["eval_steps"] = get("train.save_steps")
|
||||
args["load_best_model_at_end"] = True
|
||||
|
||||
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
|
||||
return args
|
||||
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
|
||||
get = lambda name: data[self.manager.get_elem(name)]
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
|
@ -188,27 +193,28 @@ class Runner:
|
|||
args.pop("do_eval", None)
|
||||
args["do_predict"] = True
|
||||
|
||||
return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
|
||||
return args
|
||||
|
||||
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||
lang, model_name, model_path, dataset, _, args = parse_func(data)
|
||||
error = self._initialize(lang, model_name, model_path, dataset)
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
parse_func = self._parse_train_args if do_train else self._parse_eval_args
|
||||
lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
|
||||
self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
|
||||
error = self._initialize(lang, model_name, model_path, dataset)
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
self.running = True
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.running = True
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
|
Loading…
Reference in New Issue