support warning in webui

hiyouga 2023-11-02 17:57:04 +08:00
parent f8703aac08
commit 9cde5e8af6
4 changed files with 39 additions and 32 deletions

View File

@@ -73,14 +73,15 @@ def get_template(model_name: str) -> str:
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
-    save_dir = get_save_dir(model_name, finetuning_type)
-    if save_dir and os.path.isdir(save_dir):
-        for checkpoint in os.listdir(save_dir):
-            if (
-                os.path.isdir(os.path.join(save_dir, checkpoint))
-                and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
-            ):
-                checkpoints.append(checkpoint)
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for checkpoint in os.listdir(save_dir):
+                if (
+                    os.path.isdir(os.path.join(save_dir, checkpoint))
+                    and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
+                ):
+                    checkpoints.append(checkpoint)
     return gr.update(value=[], choices=checkpoints)
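
Note: with this change, an empty model name clears the checkpoint dropdown instead of scanning a non-existent save directory. A minimal sketch of how such a helper is typically wired to a Gradio dropdown; the component names below are illustrative and not taken from this commit:

    import gradio as gr
    # assumes list_checkpoint (defined above) is importable from the webui utils module

    with gr.Blocks() as demo:
        model_name = gr.Dropdown(choices=["llama2-7b"], label="Model name")
        finetuning_type = gr.Dropdown(choices=["lora", "full"], value="lora", label="Finetuning type")
        checkpoints = gr.Dropdown(multiselect=True, label="Checkpoints")
        # refresh the checkpoint list whenever the model name changes
        model_name.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])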

View File

@@ -19,7 +19,7 @@ class Engine:
         self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))

     def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
-        return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
+        return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

     def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
         user_config = load_config()
@@ -42,7 +42,7 @@ class Engine:
         if not self.pure_chat:
             if self.runner.alive:
-                yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
+                yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
                 if self.runner.do_train:
                     yield self._form_dict({"train.resume_btn": {"value": True}})
                 else:
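
Note: _form_dict resolves dotted "tab.element" names through the manager and expands each value dict into a gr.update call, so resume() can describe UI state by name. A rough, self-contained sketch of that mapping, assuming a manager object that already holds the components (names are illustrative):

    import gradio as gr

    def form_dict(manager, resume_dict):
        # same shape as Engine._form_dict: dotted names -> component update dicts
        return {manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}

    # form_dict(manager, {"train.resume_btn": {"value": True}})
    # -> {<resume button Component>: gr.update(value=True)}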

View File

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, Dict, List, Set

 if TYPE_CHECKING:
     from gradio.components import Component
@@ -9,14 +9,14 @@ class Manager:
     def __init__(self) -> None:
         self.all_elems: Dict[str, Dict[str, "Component"]] = {}

-    def get_elem(self, name: str) -> "Component":
+    def get_elem_by_name(self, name: str) -> "Component":
         r"""
         Example: top.lang, train.dataset
         """
         tab_name, elem_name = name.split(".")
         return self.all_elems[tab_name][elem_name]

-    def get_base_elems(self):
+    def get_base_elems(self) -> Set["Component"]:
         return {
             self.all_elems["top"]["lang"],
             self.all_elems["top"]["model_name"],
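
Note: get_elem_by_name (renamed from get_elem) looks components up by a dotted "tab.element" path. A self-contained sketch of the lookup pattern, not the project's actual registration code (the registered elements here are made up for illustration):

    from typing import Dict
    import gradio as gr
    from gradio.components import Component

    class MiniManager:
        def __init__(self) -> None:
            # one dict of components per tab, e.g. "top", "train", "eval"
            self.all_elems: Dict[str, Dict[str, Component]] = {}

        def get_elem_by_name(self, name: str) -> Component:
            tab_name, elem_name = name.split(".")
            return self.all_elems[tab_name][elem_name]

    manager = MiniManager()
    manager.all_elems["top"] = {"lang": gr.Dropdown(choices=["en", "zh"], label="Language")}
    lang_box = manager.get_elem_by_name("top.lang")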

View File

@@ -26,12 +26,15 @@ class Runner:
     def __init__(self, manager: "Manager") -> None:
         self.manager = manager
+        """ Resume """
         self.thread: "Thread" = None
-        self.data: Dict["Component", Any] = None
         self.do_train = True
+        self.running_data: Dict["Component", Any] = None
         self.monitor_inputs: Dict[str, str] = None
+        """ State """
         self.aborted = False
         self.running = False
+        """ Handler """
         self.logger_handler = LoggerHandler()
         self.logger_handler.setLevel(logging.INFO)
         logging.root.addHandler(self.logger_handler)
@@ -45,7 +48,11 @@ class Runner:
         self.aborted = True
         self.running = False

-    def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
+    def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        dataset = get("train.dataset") if do_train else get("eval.dataset")
         if self.running:
             return ALERTS["err_conflict"][lang]
@@ -72,8 +79,8 @@ class Runner:
         else:
             return finish_info

-    def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
-        get = lambda name: data[self.manager.get_elem(name)]
+    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()

         if get("top.checkpoints"):
@@ -83,8 +90,6 @@ class Runner:
         else:
             checkpoint_dir = None

-        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             model_name_or_path=get("top.model_path"),
@@ -119,7 +124,7 @@ class Runner:
             lora_target=get("train.lora_target") or get_module(get("top.model_name")),
             additional_target=get("train.additional_target") if get("train.additional_target") else None,
             resume_lora_training=get("train.resume_lora_training"),
-            output_dir=output_dir
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
         )
         args[get("train.compute_type")] = True
         args["disable_tqdm"] = True
@@ -142,10 +147,10 @@ class Runner:
             args["eval_steps"] = get("train.save_steps")
             args["load_best_model_at_end"] = True

-        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
+        return args

-    def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
-        get = lambda name: data[self.manager.get_elem(name)]
+    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
+        get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()

         if get("top.checkpoints"):
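
Note: both parsers now return only the flat argument dict; output_dir is built inline via get_save_dir and read back later as args["output_dir"]. A rough illustration of turning such a dict into a launch command, in the spirit of gen_cmd (the real gen_cmd and the script path are not shown in this diff, so treat the format below as an assumption):

    def to_cmd(args: dict) -> str:
        # drop None values and render the rest as "--key value" flags
        flags = " ".join("--{} {}".format(k, v) for k, v in args.items() if v is not None)
        return "CUDA_VISIBLE_DEVICES=0 python src/train_bash.py " + flags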
@@ -188,27 +193,28 @@ class Runner:
             args.pop("do_eval", None)
             args["do_predict"] = True

-        return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
+        return args

     def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        parse_func = self._parse_train_args if do_train else self._parse_eval_args
-        lang, model_name, model_path, dataset, _, args = parse_func(data)
-        error = self._initialize(lang, model_name, model_path, dataset)
+        error = self._initialize(data, do_train)
         if error:
+            gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             yield gen_cmd(args), gr.update(visible=False)

     def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-        parse_func = self._parse_train_args if do_train else self._parse_eval_args
-        lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
-        self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
-        error = self._initialize(lang, model_name, model_path, dataset)
+        error = self._initialize(data, do_train)
         if error:
+            gr.Warning(error)
             yield error, gr.update(visible=False)
         else:
-            self.running = True
+            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
             run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+            self.running = True
+            self.do_train, self.running_data = do_train, data
+            self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
             self.thread = Thread(target=run_exp, kwargs=run_kwargs)
             self.thread.start()
             yield from self.monitor()
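
Note: the gr.Warning(error) calls are what the commit title refers to: initialization errors now surface as a warning toast in the web UI in addition to being written to the output box. A minimal, self-contained sketch of the same pattern (component names are illustrative; requires a Gradio release that provides gr.Warning):

    import gradio as gr

    def start_training(model_name: str) -> str:
        if not model_name:
            gr.Warning("Please select a model.")  # toast-style warning in the UI
            return "Please select a model."       # message also shown in the status box
        return "Training started."

    with gr.Blocks() as demo:
        name = gr.Textbox(label="Model name")
        status = gr.Textbox(label="Status")
        gr.Button("Start").click(start_training, [name], [status])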