diff --git a/README.md b/README.md
index b50571fa..7adfc66c 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,7 @@
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
 
-- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- **Default module** is used for the `--lora_target` argument. You can use `--lora_target all` to specify all the available modules.
 - For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
 
 ## Supported Training Approaches
@@ -176,7 +176,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset wiki_demo \
-    --template default \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
     --output_dir path_to_pt_checkpoint \
diff --git a/README_zh.md b/README_zh.md
index 28caeab2..e8c824d7 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -62,7 +62,7 @@
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
 
-- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
+- **默认模块**应作为 `--lora_target` 参数的默认值，可使用 `--lora_target all` 参数指定全部模块。
 - 对于所有“基座”(Base)模型，`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
 
 ## 训练方法
@@ -176,7 +176,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset wiki_demo \
-    --template default \
     --finetuning_type lora \
     --lora_target q_proj,v_proj \
     --output_dir path_to_pt_checkpoint \
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 175cb6e9..f9ee2bea 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -1,10 +1,8 @@
 import gc
 import torch
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
-from llmtuner.extras.constants import LAYERNORM_NAMES
-
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
 
@@ -29,12 +27,6 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-def get_logits_processor() -> LogitsProcessorList:
-    logits_processor = LogitsProcessorList()
-    logits_processor.append(InfNanRemoveLogitsProcessor())
-    return logits_processor
-
-
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -57,42 +49,10 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
-# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
-# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
-def prepare_model_for_training(
-    model: "PreTrainedModel",
-    finetuning_type: str,
-    output_layer_name: Optional[str] = "lm_head",
-    use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
-) -> "PreTrainedModel":
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
-
-    if use_gradient_checkpointing:
-        if hasattr(model, "enable_input_require_grads"):
-            model.enable_input_require_grads()
-        else:
-            def make_inputs_require_grad(module, input, output):
-                output.requires_grad_(True)
-            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
-        model.gradient_checkpointing_enable()
-        model.config.use_cache = False # turn off when gradient checkpointing is enabled
-
-    if finetuning_type != "full" and hasattr(model, output_layer_name):
-        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
-        input_dtype = output_layer.weight.dtype
-
-        class CastOutputToFloat(torch.nn.Sequential):
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                return super().forward(x.to(input_dtype)).to(torch.float32)
-
-        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
-
-    return model
+def get_logits_processor() -> LogitsProcessorList:
+    logits_processor = LogitsProcessorList()
+    logits_processor.append(InfNanRemoveLogitsProcessor())
+    return logits_processor
 
 
 def torch_gc() -> None:
diff --git a/src/llmtuner/extras/save_and_load.py b/src/llmtuner/extras/save_and_load.py
index 32dc651c..af66248d 100644
--- a/src/llmtuner/extras/save_and_load.py
+++ b/src/llmtuner/extras/save_and_load.py
@@ -42,8 +42,8 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
         logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
         return False
     valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
-    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
-    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
-    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
+    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
+    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
+    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
+    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
     return True
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 5db56876..1635a2b7 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -12,6 +12,7 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import load_trainable_params
+from llmtuner.tuner.core.utils import find_all_linear_modules
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
@@ -81,13 +82,18 @@ def init_adapter(
             model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
 
         if is_trainable and latest_checkpoint is None: # create new lora weights while training
+            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
+                target_modules = find_all_linear_modules(model, model_args.quantization_bit)
+            else:
+                target_modules = finetuning_args.lora_target
+
             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
                 inference_mode=False,
                 r=finetuning_args.lora_rank,
                 lora_alpha=finetuning_args.lora_alpha,
                 lora_dropout=finetuning_args.lora_dropout,
-                target_modules=finetuning_args.lora_target
+                target_modules=target_modules
             )
             model = get_peft_model(model, lora_config)
 
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 3fffade4..7941e909 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -23,10 +23,11 @@ except ImportError:
     from transformers.integrations import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters, prepare_model_for_training
+from llmtuner.extras.misc import count_parameters
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
+from llmtuner.tuner.core.utils import prepare_model_for_training
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py
new file mode 100644
index 00000000..b40ce893
--- /dev/null
+++ b/src/llmtuner/tuner/core/utils.py
@@ -0,0 +1,72 @@
+import torch
+from typing import TYPE_CHECKING, List, Optional
+
+from llmtuner.extras.constants import LAYERNORM_NAMES
+
+if TYPE_CHECKING:
+    from transformers.modeling_utils import PreTrainedModel
+
+
+def find_all_linear_modules(
+    model: "PreTrainedModel",
+    quantization_bit: Optional[int] = None,
+    output_layer_name: Optional[str] = "lm_head"
+) -> List[str]:
+    if quantization_bit is not None:
+        import bitsandbytes as bnb
+        linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
+    else:
+        linear_cls = torch.nn.Linear
+
+    module_names = set()
+    for name, module in model.named_modules():
+        if isinstance(module, linear_cls):
+            module_names.add(name.split(".")[-1])
+
+    if output_layer_name in module_names:
+        module_names.remove(output_layer_name)
+
+    return list(module_names)
+
+
+def prepare_model_for_training(
+    model: "PreTrainedModel",
+    finetuning_type: str,
+    output_layer_name: Optional[str] = "lm_head",
+    use_gradient_checkpointing: Optional[bool] = True,
+    layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
+) -> "PreTrainedModel":
+    r"""
+    Includes:
+        (1) cast the layernorm in fp32
+        (2) make output embedding layer require grads
+        (3) upcast the lm_head to fp32
+    Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
+    """
+    for name, param in model.named_parameters():
+        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
+            param.data = param.data.to(torch.float32)
+
+    if use_gradient_checkpointing:
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+        model.gradient_checkpointing_enable()
+        model.config.use_cache = False # turn off when gradient checkpointing is enabled
+
+    if finetuning_type != "full" and hasattr(model, output_layer_name):
+        output_layer: torch.nn.Linear = getattr(model, output_layer_name)
+        input_dtype = output_layer.weight.dtype
+
+        class CastOutputToFloat(torch.nn.Sequential):
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return super().forward(x.to(input_dtype)).to(torch.float32)
+
+        setattr(model, output_layer_name, CastOutputToFloat(output_layer))
+
+    return model
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index c5f22294..95d7b613 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -16,8 +16,8 @@ USER_CONFIG = "user.config"
 DATA_CONFIG = "dataset_info.json"
 
 
-def get_save_dir(model_name: str) -> str:
-    return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
+def get_save_dir(*args) -> os.PathLike:
+    return os.path.join(DEFAULT_SAVE_DIR, *args)
 
 
 def get_config_path() -> os.PathLike:
@@ -29,7 +29,7 @@ def load_config() -> Dict[str, Any]:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
-        return {"lang": "", "last_model": "", "path_dict": {}}
+        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
 
 
 def save_config(lang: str, model_name: str, model_path: str) -> None:
@@ -56,7 +56,7 @@ def get_template(model_name: str) -> str:
 
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
-    save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
+    save_dir = get_save_dir(model_name, finetuning_type)
     if save_dir and os.path.isdir(save_dir):
         for checkpoint in os.listdir(save_dir):
             if (
diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py
index 14aef162..489ccf2e 100644
--- a/src/llmtuner/webui/components/infer.py
+++ b/src/llmtuner/webui/components/infer.py
@@ -16,7 +16,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
 
     info_box = gr.Textbox(show_label=False, interactive=False)
 
-    chat_model = WebChatModel()
+    chat_model = WebChatModel(lazy_init=True)
     chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
 
     load_btn.click(
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 9b46ed9d..f1df34e3 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -12,7 +12,7 @@ from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
 from llmtuner.tuner import run_exp
-from llmtuner.webui.common import get_model_path, get_save_dir
+from llmtuner.webui.common import get_model_path, get_save_dir, load_config
 from llmtuner.webui.locales import ALERTS
 from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
 
@@ -97,21 +97,25 @@ class Runner:
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
+                [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
             )
         else:
             checkpoint_dir = None
 
-        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+        output_dir = get_save_dir(model_name, finetuning_type, output_dir)
+
+        user_config = load_config()
+        cache_dir = user_config.get("cache_dir", None)
 
         args = dict(
             stage=TRAINING_STAGES[training_stage],
             model_name_or_path=get_model_path(model_name),
             do_train=True,
-            overwrite_cache=True,
+            overwrite_cache=False,
+            cache_dir=cache_dir,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
+            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
             template=template,
             system_prompt=system_prompt,
             dataset_dir=dataset_dir,
@@ -172,22 +176,26 @@ class Runner:
     ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+                [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
            )
-            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
+            output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
         else:
             checkpoint_dir = None
-            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
+            output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
+
+        user_config = load_config()
+        cache_dir = user_config.get("cache_dir", None)
 
         args = dict(
             stage="sft",
             model_name_or_path=get_model_path(model_name),
             do_eval=True,
-            overwrite_cache=True,
+            overwrite_cache=False,
             predict_with_generate=True,
+            cache_dir=cache_dir,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
+            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
             template=template,
             system_prompt=system_prompt,
             dataset_dir=dataset_dir,
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 56868b49..52016378 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -90,7 +90,7 @@ def get_eval_results(path: os.PathLike) -> str:
 
 
 def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
-    log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
+    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
     if not os.path.isfile(log_file):
         return None
 
@@ -139,7 +139,7 @@ def save_model(
         return
 
     checkpoint_dir = ",".join(
-        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+        [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
     )
 
     if not save_dir:
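
Reviewer note (not part of the patch): the new `find_all_linear_modules` helper is what backs `--lora_target all`. Below is a minimal, self-contained sketch of the same discovery logic, using a toy `torch.nn.ModuleDict` in place of a real `PreTrainedModel`; the `q_proj`/`v_proj`/`lm_head` names are illustrative assumptions and the quantized `bitsandbytes` branch is omitted.

```python
import torch

# Toy stand-in for a transformer; "lm_head" plays the role of the excluded output layer.
model = torch.nn.ModuleDict({
    "q_proj": torch.nn.Linear(8, 8),
    "v_proj": torch.nn.Linear(8, 8),
    "lm_head": torch.nn.Linear(8, 32),
})

module_names = set()
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        module_names.add(name.split(".")[-1])  # keep only the leaf name, e.g. "q_proj"
module_names.discard("lm_head")  # never wrap the output layer with LoRA

print(sorted(module_names))  # ['q_proj', 'v_proj']
```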
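Reviewer note (not part of the patch): the `save_and_load.py` change registers the value-head buffers with `persistent=False`, which keeps them out of `state_dict()` so they are not written into saved checkpoints. A small sketch of that PyTorch behavior, with a toy module and buffer names chosen for illustration:

```python
import torch

class ToyHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False: lives on the module but is excluded from state_dict().
        self.register_buffer("reward_head_weight", torch.ones(4), persistent=False)
        # Default (persistent=True): saved and loaded with the checkpoint.
        self.register_buffer("default_head_bias", torch.zeros(4))

head = ToyHead()
print(sorted(head.state_dict().keys()))  # ['default_head_bias'] only
```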