From 80708717329b4552920dd4ce8cebc683e65d54c5 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 29 May 2024 23:55:38 +0800 Subject: [PATCH] better llamaboard * easily resume from checkpoint * support full and freeze checkpoints * faster ui --- src/llamafactory/__init__.py | 2 +- src/llamafactory/extras/constants.py | 21 +++- src/llamafactory/hparams/parser.py | 13 ++- src/llamafactory/webui/chatter.py | 33 +++--- src/llamafactory/webui/common.py | 109 +++++++++++++------- src/llamafactory/webui/components/eval.py | 4 +- src/llamafactory/webui/components/export.py | 29 +++--- src/llamafactory/webui/components/top.py | 23 ++--- src/llamafactory/webui/components/train.py | 33 +++--- src/llamafactory/webui/engine.py | 17 ++- src/llamafactory/webui/locales.py | 24 ++--- src/llamafactory/webui/manager.py | 2 +- src/llamafactory/webui/runner.py | 79 +++++++------- src/llamafactory/webui/utils.py | 107 ++++++++++++++++--- 14 files changed, 303 insertions(+), 193 deletions(-) diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index b889e268..78230937 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -1,4 +1,4 @@ -# Level: api, webui > chat, eval, train > data, model > extras, hparams +# Level: api, webui > chat, eval, train > data, model > hparams > extras from .cli import VERSION diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 5e2ee3ce..f365016f 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -2,6 +2,19 @@ from collections import OrderedDict, defaultdict from enum import Enum from typing import Dict, Optional +from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME +from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME + + +CHECKPOINT_NAMES = { + SAFE_ADAPTER_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +} CHOICES = ["A", "B", "C", "D"] @@ -26,9 +39,9 @@ LAYERNORM_NAMES = {"norm", "ln"} METHODS = ["full", "freeze", "lora"] -MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"] +MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} -PEFT_METHODS = ["lora"] +PEFT_METHODS = {"lora"} RUNNING_LOG = "running_log.txt" @@ -49,9 +62,9 @@ TRAINING_STAGES = { "Pre-Training": "pt", } -STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"] +STAGES_USE_PAIR_DATA = {"rm", "dpo"} -SUPPORTED_CLASS_FOR_S2ATTN = ["llama"] +SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} V_HEAD_WEIGHTS_NAME = "value_head.bin" diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index adb206f8..b3c673be 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -11,6 +11,7 @@ from transformers.trainer_utils import get_last_checkpoint from transformers.utils import is_torch_bf16_gpu_available from transformers.utils.versions import require_version +from ..extras.constants import CHECKPOINT_NAMES from ..extras.logging import get_logger from ..extras.misc import check_dependencies, get_current_device from .data_args import DataArguments @@ -255,13 +256,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: and can_resume_from_checkpoint ): last_checkpoint = 
get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and any(
+            os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
+        ):
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
+
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info(
-                "Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
-                    training_args.resume_from_checkpoint
-                )
-            )
+            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
+            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
 
     if (
         finetuning_args.stage in ["rm", "ppo"]
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index a92f6ef7..c82710d3 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -6,6 +6,7 @@ from numpy.typing import NDArray
 
 from ..chat import ChatModel
 from ..data import Role
+from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
 from .common import get_save_dir
@@ -44,13 +45,14 @@ class WebChatModel(ChatModel):
 
     def load_model(self, data) -> Generator[str, None, None]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
+        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
+        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
-        elif not get("top.model_name"):
+        elif not model_name:
             error = ALERTS["err_no_model"][lang]
-        elif not get("top.model_path"):
+        elif not model_path:
             error = ALERTS["err_no_path"][lang]
         elif self.demo_mode:
             error = ALERTS["err_demo"][lang]
@@ -60,21 +62,10 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
-            model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
-            finetuning_type=get("top.finetuning_type"),
+            model_name_or_path=model_path,
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
@@ -83,8 +74,16 @@ class WebChatModel(ChatModel):
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
-        super().__init__(args)
 
+        if checkpoint_path:
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
+        super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
     def unload_model(self, data) -> Generator[str, None, None]:
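The branch at the end of load_model is the pattern this patch repeats in export.py and runner.py: with LoRA (the only entry in PEFT_METHODS) the multiselect checkpoint dropdown yields a list of adapters that are joined into adapter_name_or_path, while a full or freeze checkpoint is a single directory that simply replaces model_name_or_path. A standalone sketch of that shared logic; the helper name apply_checkpoint_args and the sample values are illustrative, not part of the patch:

import os
from typing import Any, Dict, List, Union

PEFT_METHODS = {"lora"}
DEFAULT_SAVE_DIR = "saves"


def get_save_dir(*paths: str) -> str:
    # Mirrors webui/common.py: strip separators so names cannot escape the save root.
    paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
    return os.path.join(DEFAULT_SAVE_DIR, *paths)


def apply_checkpoint_args(
    args: Dict[str, Any], model_name: str, finetuning_type: str, checkpoint_path: Union[str, List[str]]
) -> None:
    # LoRA checkpoints stack as adapters; full/freeze checkpoints replace the base model.
    if not checkpoint_path:
        return
    if finetuning_type in PEFT_METHODS:  # multiselect dropdown -> list
        args["adapter_name_or_path"] = ",".join(
            get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path
        )
    else:  # single-select dropdown -> str
        args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)


args = {"model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct"}
apply_checkpoint_args(args, "LLaMA3-8B", "lora", ["train_1", "train_2"])
# -> args["adapter_name_or_path"] == "saves/LLaMA3-8B/lora/train_1,saves/LLaMA3-8B/lora/train_2"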
diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py
index ea82fd88..62004bce 100644
--- a/src/llamafactory/webui/common.py
+++ b/src/llamafactory/webui/common.py
@@ -1,12 +1,12 @@
 import json
 import os
 from collections import defaultdict
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
-from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
 from yaml import safe_dump, safe_load
 
 from ..extras.constants import (
+    CHECKPOINT_NAMES,
     DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
@@ -29,7 +29,6 @@ if is_gradio_available():
 
 logger = get_logger(__name__)
 
-ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
@@ -38,19 +37,31 @@ USER_CONFIG = "user_config.yaml"
 
 
 def get_save_dir(*paths: str) -> os.PathLike:
+    r"""
+    Gets the path to saved model checkpoints.
+    """
     paths = (path.replace(os.path.sep, "").replace(" ", "").strip() for path in paths)
     return os.path.join(DEFAULT_SAVE_DIR, *paths)
 
 
 def get_config_path() -> os.PathLike:
+    r"""
+    Gets the path to user config.
+    """
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def get_save_path(config_path: str) -> os.PathLike:
+def get_arg_save_path(config_path: str) -> os.PathLike:
+    r"""
+    Gets the path to saved arguments.
+    """
     return os.path.join(DEFAULT_CONFIG_DIR, config_path)
 
 
 def load_config() -> Dict[str, Any]:
+    r"""
+    Loads user config if it exists.
+    """
     try:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return safe_load(f)
@@ -59,6 +70,9 @@
 
 
 def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+    r"""
+    Saves user config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
@@ -69,23 +83,10 @@
         safe_dump(user_config, f)
 
 
-def load_args(config_path: str) -> Optional[Dict[str, Any]]:
-    try:
-        with open(get_save_path(config_path), "r", encoding="utf-8") as f:
-            return safe_load(f)
-    except Exception:
-        return None
-
-
-def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
-    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
-    with open(get_save_path(config_path), "w", encoding="utf-8") as f:
-        safe_dump(config_dict, f)
-
-    return str(get_save_path(config_path))
-
-
-def get_model_path(model_name: str) -> str:
+def get_model_path(model_name: str) -> Optional[str]:
+    r"""
+    Gets the model path according to the model name.
+    """
     user_config = load_config()
     path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
     model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
@@ -99,40 +100,71 @@
 
 
 def get_prefix(model_name: str) -> str:
+    r"""
+    Gets the prefix of the model name to obtain the model family.
+    """
     return model_name.split("-")[0]
 
 
+def get_model_info(model_name: str) -> Tuple[str, str, bool]:
+    r"""
+    Gets the necessary information of this model.
+
+    Returns:
+        model_path (str)
+        template (str)
+        visual (bool)
+    """
+    return get_model_path(model_name), get_template(model_name), get_visual(model_name)
+
+
 def get_module(model_name: str) -> str:
-    return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
+    r"""
+    Gets the LoRA modules of this model.
+    """
+    return DEFAULT_MODULE.get(get_prefix(model_name), "all")
 
 
 def get_template(model_name: str) -> str:
+    r"""
+    Gets the template name if the model is a chat model.
+ """ if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE: return DEFAULT_TEMPLATE[get_prefix(model_name)] return "default" def get_visual(model_name: str) -> bool: + r""" + Judges if the model is a vision language model. + """ return get_prefix(model_name) in VISION_MODELS -def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown": - if finetuning_type not in PEFT_METHODS: - return gr.Dropdown(value=[], choices=[], interactive=False) - - adapters = [] - if model_name and finetuning_type == "lora": +def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown": + r""" + Lists all available checkpoints. + """ + checkpoints = [] + if model_name: save_dir = get_save_dir(model_name, finetuning_type) if save_dir and os.path.isdir(save_dir): - for adapter in os.listdir(save_dir): - if os.path.isdir(os.path.join(save_dir, adapter)) and any( - os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES + for checkpoint in os.listdir(save_dir): + if os.path.isdir(os.path.join(save_dir, checkpoint)) and any( + os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CHECKPOINT_NAMES ): - adapters.append(adapter) - return gr.Dropdown(value=[], choices=adapters, interactive=True) + checkpoints.append(checkpoint) + + if finetuning_type in PEFT_METHODS: + return gr.Dropdown(value=[], choices=checkpoints, multiselect=True) + else: + return gr.Dropdown(value=None, choices=checkpoints, multiselect=False) def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: + r""" + Loads dataset_info.json. + """ if dataset_dir == "ONLINE": logger.info("dataset_dir is ONLINE, using online dataset.") return {} @@ -145,12 +177,11 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: return {} -def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": +def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown": + r""" + Lists all available datasets in the dataset dir for the training stage. 
+    """
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
     ranking = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
-    return gr.Dropdown(value=[], choices=datasets)
-
-
-def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
-    return gr.Button(value=(TRAINING_STAGES[training_stage] == "pt"))
+    return gr.Dropdown(choices=datasets)
diff --git a/src/llamafactory/webui/components/eval.py b/src/llamafactory/webui/components/eval.py
index 8b70283b..99215fc2 100644
--- a/src/llamafactory/webui/components/eval.py
+++ b/src/llamafactory/webui/components/eval.py
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 
 from ...extras.packages import is_gradio_available
-from ..common import DEFAULT_DATA_DIR, list_dataset
+from ..common import DEFAULT_DATA_DIR, list_datasets
 from .data import create_preview_box
 
 
@@ -74,6 +74,6 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
     stop_btn.click(engine.runner.set_abort)
     resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
 
-    dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
+    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)
 
     return elem_dict
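Replacing the eager dataset_dir.change chain with dataset.focus is the "faster ui" idea in this patch: the dropdown's choices are recomputed lazily, the moment the user opens it, rather than on every upstream change (the same trick drives checkpoint_path.focus and reward_model.focus below). A minimal self-contained sketch of the pattern; the folder-scanning callback is an illustrative stand-in for the patch's dataset_info.json lookup:

import os

import gradio as gr


def list_datasets(dataset_dir: str) -> "gr.Dropdown":
    # Recompute choices only when the user actually opens the dropdown.
    choices = sorted(os.listdir(dataset_dir)) if os.path.isdir(dataset_dir) else []
    return gr.Dropdown(choices=choices)


with gr.Blocks() as demo:
    dataset_dir = gr.Textbox(value="data")
    dataset = gr.Dropdown(multiselect=True)
    # focus() fires on user interaction, so no refresh button is needed, and
    # queue=False keeps the update off the event queue for a snappier UI.
    dataset.focus(list_datasets, [dataset_dir], [dataset], queue=False)

demo.launch()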
diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py
index 134b77e0..2f354011 100644
--- a/src/llamafactory/webui/components/export.py
+++ b/src/llamafactory/webui/components/export.py
@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING, Dict, Generator, List
+from typing import TYPE_CHECKING, Dict, Generator, List, Union
 
+from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
@@ -24,8 +25,8 @@ def save_model(
     lang: str,
     model_name: str,
     model_path: str,
-    adapter_path: List[str],
     finetuning_type: str,
+    checkpoint_path: Union[str, List[str]],
     template: str,
     visual_inputs: bool,
     export_size: int,
@@ -45,9 +46,9 @@ def save_model(
         error = ALERTS["err_no_export_dir"][lang]
     elif export_quantization_bit in GPTQ_BITS and not export_quantization_dataset:
         error = ALERTS["err_no_dataset"][lang]
-    elif export_quantization_bit not in GPTQ_BITS and not adapter_path:
+    elif export_quantization_bit not in GPTQ_BITS and not checkpoint_path:
         error = ALERTS["err_no_adapter"][lang]
-    elif export_quantization_bit in GPTQ_BITS and adapter_path:
+    elif export_quantization_bit in GPTQ_BITS and isinstance(checkpoint_path, list):
         error = ALERTS["err_gptq_lora"][lang]
 
     if error:
         gr.Warning(error)
         yield error
         return
 
-    if adapter_path:
-        adapter_name_or_path = ",".join(
-            [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
-        )
-    else:
-        adapter_name_or_path = None
-
     args = dict(
         model_name_or_path=model_path,
-        adapter_name_or_path=adapter_name_or_path,
         finetuning_type=finetuning_type,
         template=template,
         visual_inputs=visual_inputs,
@@ -77,6 +70,14 @@ def save_model(
         export_legacy_format=export_legacy_format,
     )
 
+    if checkpoint_path:
+        if finetuning_type in PEFT_METHODS:  # list
+            args["adapter_name_or_path"] = ",".join(
+                [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
+            )
+        else:  # str
+            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
+
     yield ALERTS["info_exporting"][lang]
     export_model(args)
     torch_gc()
@@ -86,7 +87,7 @@ def save_model(
 
 def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
     with gr.Row():
         export_size = gr.Slider(minimum=1, maximum=100, value=1, step=1)
-        export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
+        export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")
         export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
         export_device = gr.Radio(choices=["cpu", "cuda"], value="cpu")
         export_legacy_format = gr.Checkbox()
@@ -104,8 +105,8 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
             engine.manager.get_elem_by_id("top.lang"),
             engine.manager.get_elem_by_id("top.model_name"),
             engine.manager.get_elem_by_id("top.model_path"),
-            engine.manager.get_elem_by_id("top.adapter_path"),
             engine.manager.get_elem_by_id("top.finetuning_type"),
+            engine.manager.get_elem_by_id("top.checkpoint_path"),
             engine.manager.get_elem_by_id("top.template"),
             engine.manager.get_elem_by_id("top.visual_inputs"),
             export_size,
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index a75a4d62..ca093584 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Dict
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
-from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
+from ..common import get_model_info, list_checkpoints, save_config
 from ..utils import can_quantize
 
 
@@ -25,8 +25,7 @@ def create_top() -> Dict[str, "Component"]:
 
     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
-        adapter_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=5)
-        refresh_btn = gr.Button(scale=1)
+        checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
@@ -36,27 +35,17 @@ def create_top() -> Dict[str, "Component"]:
             booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
             visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        get_model_path, [model_name], [model_path], queue=False
-    ).then(get_template, [model_name], [template], queue=False).then(
-        get_visual, [model_name], [visual_inputs], queue=False
-    )  # do not save config since the below line will save
-
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
-
-    finetuning_type.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
-        can_quantize, [finetuning_type], [quantization_bit], queue=False
-    )
-
-    refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
+    finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
+    checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
 
     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
        finetuning_type=finetuning_type,
-        adapter_path=adapter_path,
-        refresh_btn=refresh_btn,
+        checkpoint_path=checkpoint_path,
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
diff --git
a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 8db5c2ba..6f742bb1 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -5,8 +5,9 @@ from transformers.trainer_utils import SchedulerType from ...extras.constants import TRAINING_STAGES from ...extras.misc import get_device_count from ...extras.packages import is_gradio_available -from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset -from ..components.data import create_preview_box +from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets +from ..utils import change_stage, check_output_dir, list_output_dirs +from .data import create_preview_box if is_gradio_available(): @@ -256,11 +257,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: with gr.Row(): with gr.Column(scale=3): with gr.Row(): - output_dir = gr.Textbox() + initial_dir = gr.Textbox(visible=False, interactive=False) + output_dir = gr.Dropdown(allow_custom_value=True) config_path = gr.Textbox() with gr.Row(): - device_count = gr.Textbox(value=str(get_device_count()), interactive=False) + device_count = gr.Textbox(value=str(get_device_count() or 1), interactive=False) ds_stage = gr.Dropdown(choices=["none", "2", "3"], value="none") ds_offload = gr.Checkbox() @@ -282,6 +284,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: arg_load_btn=arg_load_btn, start_btn=start_btn, stop_btn=stop_btn, + initial_dir=initial_dir, output_dir=output_dir, config_path=config_path, device_count=device_count, @@ -295,24 +298,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: ) output_elems = [output_box, progress_bar, loss_viewer] + lang = engine.manager.get_elem_by_id("top.lang") + model_name = engine.manager.get_elem_by_id("top.model_name") + finetuning_type = engine.manager.get_elem_by_id("top.finetuning_type") + cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems, concurrency_limit=None) arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None) arg_load_btn.click( - engine.runner.load_args, - [engine.manager.get_elem_by_id("top.lang"), config_path], - list(input_elems) + [output_box], - concurrency_limit=None, + engine.runner.load_args, [lang, config_path], list(input_elems) + [output_box], concurrency_limit=None ) start_btn.click(engine.runner.run_train, input_elems, output_elems) stop_btn.click(engine.runner.set_abort) resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None) - dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) - training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False).then( - list_adapters, - [engine.manager.get_elem_by_id("top.model_name"), engine.manager.get_elem_by_id("top.finetuning_type")], - [reward_model], - queue=False, - ).then(autoset_packing, [training_stage], [packing], queue=False) + training_stage.change(change_stage, [training_stage], [dataset, packing], queue=False) + dataset.focus(list_datasets, [dataset_dir, training_stage], [dataset], queue=False) + reward_model.focus(list_checkpoints, [model_name, finetuning_type], [reward_model], queue=False) + output_dir.change( + list_output_dirs, [model_name, finetuning_type, initial_dir], [output_dir], concurrency_limit=None + ).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None) return elem_dict diff --git 
a/src/llamafactory/webui/engine.py b/src/llamafactory/webui/engine.py index fb568737..00877115 100644 --- a/src/llamafactory/webui/engine.py +++ b/src/llamafactory/webui/engine.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING, Any, Dict from .chatter import WebChatModel -from .common import get_model_path, list_dataset, load_config +from .common import load_config from .locales import LOCALES from .manager import Manager from .runner import Runner -from .utils import get_time, save_ds_config +from .utils import create_ds_config, get_time if TYPE_CHECKING: @@ -20,7 +20,7 @@ class Engine: self.runner = Runner(self.manager, demo_mode) self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat)) if not demo_mode: - save_ds_config() + create_ds_config() def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]: r""" @@ -40,16 +40,15 @@ class Engine: init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}} if not self.pure_chat: - init_dict["train.dataset"] = {"choices": list_dataset().choices} - init_dict["eval.dataset"] = {"choices": list_dataset().choices} - init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())} - init_dict["train.config_path"] = {"value": "{}.yaml".format(get_time())} - init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())} + current_time = get_time() + init_dict["train.initial_dir"] = {"value": "train_{}".format(current_time)} + init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)} + init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)} + init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)} init_dict["infer.image_box"] = {"visible": False} if user_config.get("last_model", None): init_dict["top.model_name"] = {"value": user_config["last_model"]} - init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])} yield self._update_component(init_dict) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 4657f9a3..5b11c853 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -46,26 +46,15 @@ LOCALES = { "label": "微调方法", }, }, - "adapter_path": { + "checkpoint_path": { "en": { - "label": "Adapter path", + "label": "Checkpoint path", }, "ru": { - "label": "Путь к адаптеру", + "label": "Путь контрольной точки", }, "zh": { - "label": "适配器路径", - }, - }, - "refresh_btn": { - "en": { - "value": "Refresh adapters", - }, - "ru": { - "value": "Обновить адаптеры", - }, - "zh": { - "value": "刷新适配器", + "label": "检查点路径", }, }, "advanced_tab": { @@ -1531,6 +1520,11 @@ ALERTS = { "ru": "Среда CUDA не обнаружена.", "zh": "未检测到 CUDA 环境。", }, + "warn_output_dir_exists": { + "en": "Output dir already exists, will resume training from here.", + "ru": "Выходной каталог уже существует, обучение будет продолжено отсюда.", + "zh": "输出目录已存在,将从该断点恢复训练。", + }, "info_aborting": { "en": "Aborted, wait for terminating...", "ru": "Прервано, ожидание завершения...", diff --git a/src/llamafactory/webui/manager.py b/src/llamafactory/webui/manager.py index f65fa804..326fdb8d 100644 --- a/src/llamafactory/webui/manager.py +++ b/src/llamafactory/webui/manager.py @@ -55,7 +55,7 @@ class Manager: self._id_to_elem["top.model_name"], self._id_to_elem["top.model_path"], self._id_to_elem["top.finetuning_type"], - self._id_to_elem["top.adapter_path"], + self._id_to_elem["top.checkpoint_path"], self._id_to_elem["top.quantization_bit"], 
self._id_to_elem["top.template"], self._id_to_elem["top.rope_scaling"], diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index c2e46e97..7a305d62 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional import psutil from transformers.trainer import TRAINING_ARGS_NAME -from ..extras.constants import TRAINING_STAGES +from ..extras.constants import PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc from ..extras.packages import is_gradio_available -from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args +from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config from .locales import ALERTS -from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd +from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd if is_gradio_available(): @@ -85,26 +85,16 @@ class Runner: def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() - if get("top.adapter_path"): - adapter_name_or_path = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("top.adapter_path") - ] - ) - else: - adapter_name_or_path = None - args = dict( stage=TRAINING_STAGES[get("train.training_stage")], do_train=True, model_name_or_path=get("top.model_path"), - adapter_name_or_path=adapter_name_or_path, cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, - finetuning_type=get("top.finetuning_type"), + finetuning_type=finetuning_type, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, @@ -134,13 +124,23 @@ class Runner: report_to="all" if get("train.report_to") else "none", use_galore=get("train.use_galore"), use_badam=get("train.use_badam"), - output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")), + output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")), fp16=(get("train.compute_type") == "fp16"), bf16=(get("train.compute_type") == "bf16"), pure_bf16=(get("train.compute_type") == "pure_bf16"), plot_loss=True, + ddp_timeout=180000000, ) + # checkpoints + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + # freeze config if args["finetuning_type"] == "freeze": args["freeze_trainable_layers"] = get("train.freeze_trainable_layers") @@ -156,7 +156,7 @@ class Runner: args["create_new_adapter"] = get("train.create_new_adapter") args["use_rslora"] = get("train.use_rslora") args["use_dora"] = get("train.use_dora") - args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name")) + args["lora_target"] = get("train.lora_target") or get_module(model_name) args["additional_target"] = get("train.additional_target") or None if 
args["use_llama_pro"]: @@ -164,13 +164,14 @@ class Runner: # rlhf config if args["stage"] == "ppo": - args["reward_model"] = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("train.reward_model") - ] - ) - args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full" + if finetuning_type in PEFT_METHODS: + args["reward_model"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")] + ) + else: + args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model")) + + args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full" args["ppo_score_norm"] = get("train.ppo_score_norm") args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards") args["top_k"] = 0 @@ -211,25 +212,15 @@ class Runner: def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)] + model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type") user_config = load_config() - if get("top.adapter_path"): - adapter_name_or_path = ",".join( - [ - get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter) - for adapter in get("top.adapter_path") - ] - ) - else: - adapter_name_or_path = None - args = dict( stage="sft", model_name_or_path=get("top.model_path"), - adapter_name_or_path=adapter_name_or_path, cache_dir=user_config.get("cache_dir", None), preprocessing_num_workers=16, - finetuning_type=get("top.finetuning_type"), + finetuning_type=finetuning_type, quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, template=get("top.template"), rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, @@ -245,7 +236,7 @@ class Runner: max_new_tokens=get("eval.max_new_tokens"), top_p=get("eval.top_p"), temperature=get("eval.temperature"), - output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")), + output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")), ) if get("eval.predict"): @@ -253,6 +244,14 @@ class Runner: else: args["do_eval"] = True + if get("top.checkpoint_path"): + if finetuning_type in PEFT_METHODS: # list + args["adapter_name_or_path"] = ",".join( + [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")] + ) + else: # str + args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path")) + return args def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]: @@ -296,9 +295,7 @@ class Runner: self.running = True get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)] - lang = get("top.lang") - model_name = get("top.model_name") - finetuning_type = get("top.finetuning_type") + lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type") output_dir = get("{}.output_dir".format("train" if self.do_train else "eval")) output_path = get_save_dir(model_name, finetuning_type, output_dir) @@ -356,7 +353,7 @@ class Runner: config_dict: Dict[str, Any] = {} lang = data[self.manager.get_elem_by_id("top.lang")] config_path = data[self.manager.get_elem_by_id("train.config_path")] - skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"] + skip_ids = ["top.lang", "top.model_path", 
"train.output_dir", "train.config_path", "train.device_count"] for elem, value in data.items(): elem_id = self.manager.get_id_by_elem(elem) if elem_id not in skip_ids: diff --git a/src/llamafactory/webui/utils.py b/src/llamafactory/webui/utils.py index 654d1f8d..09cefa0e 100644 --- a/src/llamafactory/webui/utils.py +++ b/src/llamafactory/webui/utils.py @@ -3,12 +3,13 @@ import os from datetime import datetime from typing import Any, Dict, List, Optional, Tuple -from yaml import safe_dump +from transformers.trainer_utils import get_last_checkpoint +from yaml import safe_dump, safe_load -from ..extras.constants import RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG +from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import gen_loss_plot -from .common import DEFAULT_CACHE_DIR +from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir from .locales import ALERTS @@ -17,13 +18,26 @@ if is_gradio_available(): def can_quantize(finetuning_type: str) -> "gr.Dropdown": - if finetuning_type != "lora": + r""" + Judges if the quantization is available in this finetuning type. + """ + if finetuning_type not in PEFT_METHODS: return gr.Dropdown(value="none", interactive=False) else: return gr.Dropdown(interactive=True) +def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]: + r""" + Modifys states after changing the training stage. + """ + return [], TRAINING_STAGES[training_stage] == "pt" + + def check_json_schema(text: str, lang: str) -> None: + r""" + Checks if the json schema is valid. + """ try: tools = json.loads(text) if tools: @@ -38,11 +52,17 @@ def check_json_schema(text: str, lang: str) -> None: def clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]: + r""" + Removes args with NoneType or False or empty string value. + """ no_skip_keys = ["packing"] return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")} def gen_cmd(args: Dict[str, Any]) -> str: + r""" + Generates arguments for previewing. + """ cmd_lines = ["llamafactory-cli train "] for k, v in clean_cmd(args).items(): cmd_lines.append(" --{} {} ".format(k, str(v))) @@ -52,17 +72,39 @@ def gen_cmd(args: Dict[str, Any]) -> str: return cmd_text +def save_cmd(args: Dict[str, Any]) -> str: + r""" + Saves arguments to launch training. + """ + output_dir = args["output_dir"] + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f: + safe_dump(clean_cmd(args), f) + + return os.path.join(output_dir, TRAINER_CONFIG) + + def get_eval_results(path: os.PathLike) -> str: + r""" + Gets scores after evaluation. + """ with open(path, "r", encoding="utf-8") as f: result = json.dumps(json.load(f), indent=4) return "```json\n{}\n```\n".format(result) def get_time() -> str: + r""" + Gets current date and time. + """ return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S") def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Optional["gr.Plot"]]: + r""" + Gets training infomation for monitor. 
+    """
     running_log = ""
     running_progress = gr.Slider(visible=False)
     running_loss = None
@@ -96,17 +138,56 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
     return running_log, running_progress, running_loss
 
 
-def save_cmd(args: Dict[str, Any]) -> str:
-    output_dir = args["output_dir"]
-    os.makedirs(output_dir, exist_ok=True)
-
-    with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
-        safe_dump(clean_cmd(args), f)
-
-    return os.path.join(output_dir, TRAINER_CONFIG)
+def load_args(config_path: str) -> Optional[Dict[str, Any]]:
+    r"""
+    Loads saved arguments.
+    """
+    try:
+        with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
+            return safe_load(f)
+    except Exception:
+        return None
 
 
-def save_ds_config() -> None:
+def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
+    r"""
+    Saves arguments.
+    """
+    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
+    with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
+        safe_dump(config_dict, f)
+
+    return str(get_arg_save_path(config_path))
+
+
+def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
+    r"""
+    Lists all the directories that training can resume from.
+    """
+    output_dirs = [initial_dir]
+    if model_name:
+        save_dir = get_save_dir(model_name, finetuning_type)
+        if save_dir and os.path.isdir(save_dir):
+            for folder in os.listdir(save_dir):
+                output_dir = os.path.join(save_dir, folder)
+                if os.path.isdir(output_dir) and get_last_checkpoint(output_dir) is not None:
+                    output_dirs.append(folder)
+
+    return gr.Dropdown(choices=output_dirs)
+
+
+def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
+    r"""
+    Checks if the output dir exists.
+    """
+    if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
+        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
+
+
+def create_ds_config() -> None:
+    r"""
+    Creates the DeepSpeed config.
+    """
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     ds_config = {
         "train_batch_size": "auto",
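Resuming is thus a two-part contract: list_output_dirs only offers directories in which transformers' get_last_checkpoint finds a checkpoint-* folder with trainer state, and the parser.py guard earlier in this patch decides between resuming and refusing to overwrite. A condensed sketch of that decision, with the CHECKPOINT_NAMES values inlined from the peft/transformers constants and the helper name resolve_resume_path invented for illustration:

import os
from typing import Optional

from transformers.trainer_utils import get_last_checkpoint

# The weight files CHECKPOINT_NAMES resolves to (peft and transformers constants).
CHECKPOINT_NAMES = {
    "adapter_model.safetensors",
    "adapter_model.bin",
    "model.safetensors.index.json",
    "model.safetensors",
    "pytorch_model.bin.index.json",
    "pytorch_model.bin",
}


def resolve_resume_path(output_dir: str) -> Optional[str]:
    # None -> fresh run; a path -> resume silently; ValueError -> the directory
    # holds finished weights but no trainer state, so refuse to overwrite it.
    last_checkpoint = get_last_checkpoint(output_dir)
    if last_checkpoint is None and any(
        os.path.isfile(os.path.join(output_dir, name)) for name in CHECKPOINT_NAMES
    ):
        raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

    return last_checkpoint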