diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 4099fe56..7d96fb5f 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -35,6 +35,8 @@ IGNORE_INDEX = -100
 
 LAYERNORM_NAMES = {"norm", "ln"}
 
+LLAMABOARD_CONFIG = "llamaboard_config.yaml"
+
 METHODS = ["full", "freeze", "lora"]
 
 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
@@ -47,10 +49,10 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
 
 SUPPORTED_MODELS = OrderedDict()
 
-TRAINER_CONFIG = "trainer_config.yaml"
-
 TRAINER_LOG = "trainer_log.jsonl"
 
+TRAINING_ARGS = "training_args.yaml"
+
 TRAINING_STAGES = {
     "Supervised Fine-Tuning": "sft",
     "Reward Modeling": "rm",
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 1a77d613..d17873f7 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -50,7 +50,7 @@ def init_adapter(
         logger.info("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
 
-    if finetuning_args.finetuning_type == "full" and is_trainable:
+    if is_trainable and finetuning_args.finetuning_type == "full":
         logger.info("Fine-tuning method: Full")
 
         forbidden_modules = set()
@@ -67,7 +67,7 @@ def init_adapter(
             else:
                 param.requires_grad_(False)
 
-    if finetuning_args.finetuning_type == "freeze" and is_trainable:
+    if is_trainable and finetuning_args.finetuning_type == "freeze":
         logger.info("Fine-tuning method: Freeze")
 
         if model_args.visual_inputs:
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index 304b56a5..37b38df0 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -50,13 +50,6 @@
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def get_arg_save_path(config_path: str) -> os.PathLike:
-    r"""
-    Gets the path to saved arguments.
-    """
-    return os.path.join(DEFAULT_CONFIG_DIR, config_path)
-
-
 def load_config() -> Dict[str, Any]:
     r"""
     Loads user config if exists.
@@ -77,24 +70,28 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
     user_config["lang"] = lang or user_config["lang"]
     if model_name:
         user_config["last_model"] = model_name
+
+    if model_name and model_path:
         user_config["path_dict"][model_name] = model_path
+
     with open(get_config_path(), "w", encoding="utf-8") as f:
         safe_dump(user_config, f)
 
 
-def get_model_path(model_name: str) -> Optional[str]:
+def get_model_path(model_name: str) -> str:
     r"""
     Gets the model path according to the model name.
""" user_config = load_config() - path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) - model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None) + path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str)) + model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "") if ( use_modelscope() and path_dict.get(DownloadSource.MODELSCOPE) and model_path == path_dict.get(DownloadSource.DEFAULT) ): # replace path model_path = path_dict.get(DownloadSource.MODELSCOPE) + return model_path diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index c794d0aa..fd0ead3d 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -36,7 +36,8 @@ def create_top() -> Dict[str, "Component"]: visual_inputs = gr.Checkbox(scale=1) model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False) - model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) + model_name.input(save_config, inputs=[lang, model_name], queue=False) + model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False) finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False) checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 74f8ef2a..72dfc858 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets -from ..utils import change_stage, check_output_dir, list_config_paths, list_output_dirs +from ..utils import change_stage, list_config_paths, list_output_dirs from .data import create_preview_box @@ -319,7 +319,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False) output_dir.change( list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None - ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None) + ) + output_dir.input( + engine.runner.check_output_dir, + [lang, model_name, finetuning_type, output_dir], + list(input_elems) + [output_box], + concurrency_limit=None, + ) config_path.change(list_config_paths, [current_time], [config_path], queue=False) return elem_dict diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index c046152c..35014628 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -5,11 +5,11 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from transformers.trainer import TRAINING_ARGS_NAME -from ..extras.constants import PEFT_METHODS, TRAINING_STAGES +from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import DEFAULT_CACHE_DIR, get_save_dir, load_config -from .locales import ALERTS +from .common import 
DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config +from .locales import ALERTS, LOCALES from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd @@ -276,6 +276,10 @@ class Runner: else: self.do_train, self.running_data = do_train, data args = self._parse_train_args(data) if do_train else self._parse_eval_args(data) + + os.makedirs(args["output_dir"], exist_ok=True) + save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data)) + env = deepcopy(os.environ) env["LLAMABOARD_ENABLED"] = "1" if args.get("deepspeed", None) is not None: @@ -284,6 +288,16 @@ class Runner: self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True) yield from self.monitor() + def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]: + config_dict = {} + skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"] + for elem, value in data.items(): + elem_id = self.manager.get_id_by_elem(elem) + if elem_id not in skip_ids: + config_dict[elem_id] = value + + return config_dict + def preview_train(self, data): yield from self._preview(data, do_train=True) @@ -349,28 +363,24 @@ class Runner: } yield return_dict - def save_args(self, data: dict): + def save_args(self, data): output_box = self.manager.get_elem_by_id("train.output_box") error = self._initialize(data, do_train=True, from_preview=True) if error: gr.Warning(error) return {output_box: error} - config_dict: Dict[str, Any] = {} lang = data[self.manager.get_elem_by_id("top.lang")] config_path = data[self.manager.get_elem_by_id("train.config_path")] - skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"] - for elem, value in data.items(): - elem_id = self.manager.get_id_by_elem(elem) - if elem_id not in skip_ids: - config_dict[elem_id] = value + os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) + save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path) - save_path = save_args(config_path, config_dict) + save_args(save_path, self._form_config_dict(data)) return {output_box: ALERTS["info_config_saved"][lang] + save_path} def load_args(self, lang: str, config_path: str): output_box = self.manager.get_elem_by_id("train.output_box") - config_dict = load_args(config_path) + config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path)) if config_dict is None: gr.Warning(ALERTS["err_config_not_found"][lang]) return {output_box: ALERTS["err_config_not_found"][lang]} @@ -380,3 +390,17 @@ class Runner: output_dict[self.manager.get_elem_by_id(elem_id)] = value return output_dict + + def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str): + output_box = self.manager.get_elem_by_id("train.output_box") + output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]} + if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): + gr.Warning(ALERTS["warn_output_dir_exists"][lang]) + output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang] + + output_dir = get_save_dir(model_name, finetuning_type, output_dir) + config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)) # load llamaboard config + for elem_id, value in config_dict.items(): + output_dict[self.manager.get_elem_by_id(elem_id)] = value + + return output_dict diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 
23e62dca..e39f2aa4 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -8,10 +8,10 @@ import psutil from transformers.trainer_utils import get_last_checkpoint from yaml import safe_dump, safe_load -from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES +from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot -from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir from .locales import ALERTS @@ -93,10 +93,10 @@ def save_cmd(args: Dict[str, Any]) -> str: output_dir = args["output_dir"] os.makedirs(output_dir, exist_ok=True) - with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f: + with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f: safe_dump(clean_cmd(args), f) - return os.path.join(output_dir, TRAINER_CONFIG) + return os.path.join(output_dir, TRAINING_ARGS) def get_eval_results(path: os.PathLike) -> str: @@ -157,22 +157,19 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]: Loads saved arguments. """ try: - with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f: + with open(config_path, "r", encoding="utf-8") as f: return safe_load(f) except Exception: return None -def save_args(config_path: str, config_dict: Dict[str, Any]) -> str: +def save_args(config_path: str, config_dict: Dict[str, Any]): r""" Saves arguments. """ - os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True) - with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f: + with open(config_path, "w", encoding="utf-8") as f: safe_dump(config_dict, f) - return str(get_arg_save_path(config_path)) - def list_config_paths(current_time: str) -> "gr.Dropdown": r""" @@ -181,13 +178,13 @@ def list_config_paths(current_time: str) -> "gr.Dropdown": config_files = ["{}.yaml".format(current_time)] if os.path.isdir(DEFAULT_CONFIG_DIR): for file_name in os.listdir(DEFAULT_CONFIG_DIR): - if file_name.endswith(".yaml"): + if file_name.endswith(".yaml") and file_name not in config_files: config_files.append(file_name) return gr.Dropdown(choices=config_files) -def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown": +def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": r""" Lists all the directories that can resume from. """ @@ -203,14 +200,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) - return gr.Dropdown(choices=output_dirs) -def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None: - r""" - Check if output dir exists. - """ - if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)): - gr.Warning(ALERTS["warn_output_dir_exists"][lang]) - - def create_ds_config() -> None: r""" Creates deepspeed config.
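
Note on the config round-trip introduced by this patch (reviewer sketch, not part of the diff): `save_args`/`load_args` now take full paths, every launch persists the board state as `llamaboard_config.yaml` inside `output_dir`, and `Runner.check_output_dir` restores that state when the UI is pointed back at an existing run. A minimal standalone sketch of the same round-trip, assuming only PyYAML; the directory name and config key below are hypothetical:

    # round_trip_sketch.py -- mirrors the save_args/load_args contract above
    import os
    from typing import Any, Dict, Optional

    from yaml import safe_dump, safe_load

    LLAMABOARD_CONFIG = "llamaboard_config.yaml"  # mirrors extras/constants.py


    def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
        # Callers now pass a full path and create directories themselves.
        with open(config_path, "w", encoding="utf-8") as f:
            safe_dump(config_dict, f)


    def load_args(config_path: str) -> Optional[Dict[str, Any]]:
        # Returns None for a missing or unreadable file, matching the webui fallback.
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                return safe_load(f)
        except Exception:
            return None


    if __name__ == "__main__":
        output_dir = "saves/demo/lora/run1"  # hypothetical output dir
        os.makedirs(output_dir, exist_ok=True)
        save_args(os.path.join(output_dir, LLAMABOARD_CONFIG), {"train.learning_rate": "5e-5"})
        print(load_args(os.path.join(output_dir, LLAMABOARD_CONFIG)))  # {'train.learning_rate': '5e-5'}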