From 7aad0b889d9a316fffd65f32a419078418fc0986 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 23 Dec 2023 00:14:33 +0800 Subject: [PATCH] support unsloth --- README.md | 2 + README_zh.md | 2 + requirements.txt | 2 +- src/llmtuner/hparams/finetuning_args.py | 41 +---- src/llmtuner/hparams/model_args.py | 32 ++++ src/llmtuner/model/adapter.py | 35 +++-- src/llmtuner/model/loader.py | 60 ++++--- src/llmtuner/model/parser.py | 3 +- src/llmtuner/model/patcher.py | 198 +++++++++++++----------- src/llmtuner/train/dpo/trainer.py | 14 +- src/llmtuner/train/tuner.py | 6 +- 11 files changed, 224 insertions(+), 171 deletions(-) diff --git a/README.md b/README.md index deb70385..c4b60fea 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation for faster LoRA tuning. Try `--use_unsloth` argument to active unsloth patch. See performance comparisons [here](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison). + [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage. diff --git a/README_zh.md b/README_zh.md index 49bf8727..f5155747 100644 --- a/README_zh.md +++ b/README_zh.md @@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846 ## 更新日志 +[23/12/23] 我们支持了使用 **[unsloth](https://github.com/unslothai/unsloth)** 大幅提升 LoRA 训练效率。请使用 `--use_unsloth` 参数启用 unsloth 优化。性能对比请查阅[此处](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 + [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 [23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。 diff --git a/requirements.txt b/requirements.txt index 35497667..07d11c42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ transformers>=4.36.2 datasets>=2.14.3 accelerate>=0.21.0 peft>=0.7.0 -trl==0.7.4 +trl>=0.7.6 gradio>=3.38.0,<4.0.0 scipy sentencepiece diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 1730a6b9..be9f9474 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -38,7 +38,7 @@ class LoraArguments: metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."} ) lora_dropout: Optional[float] = field( - default=0.1, + default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."} ) lora_rank: Optional[int] = field( @@ -70,7 +70,7 @@ class RLHFArguments: default=0.1, metadata={"help": "The beta parameter for the DPO loss."} ) - dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field( + dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field( default="sigmoid", metadata={"help": "The type of DPO loss to use."} ) @@ -133,38 +133,7 @@ class RLHFArguments: @dataclass -class ExportArguments: - r""" - Arguments pertaining to model exporting. - """ - export_dir: Optional[str] = field( - default=None, - metadata={"help": "Path to the directory to save the exported model."} - ) - export_size: Optional[int] = field( - default=1, - metadata={"help": "The file shard size (in GB) of the exported model."} - ) - export_quantization_bit: Optional[int] = field( - default=None, - metadata={"help": "The number of bits to quantize the exported model."} - ) - export_quantization_dataset: Optional[str] = field( - default=None, - metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} - ) - export_quantization_nsamples: Optional[int] = field( - default=128, - metadata={"help": "The number of samples used for quantization."} - ) - export_quantization_maxlen: Optional[int] = field( - default=1024, - metadata={"help": "The maximum length of the model inputs used for quantization."} - ) - - -@dataclass -class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments): +class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): r""" Arguments pertaining to which techniques we are going to fine-tuning with. """ @@ -201,7 +170,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." - assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." if self.stage == "ppo" and self.reward_model is None: raise ValueError("Reward model is necessary for PPO training.") @@ -209,9 +177,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.") - if self.export_quantization_bit is not None and self.export_quantization_dataset is None: - raise ValueError("Quantization dataset is necessary for exporting.") - def save_to_json(self, json_path: str): r"""Saves the content of this instance in JSON format inside `json_path`.""" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index f46f291f..24ca5dc1 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -54,6 +54,10 @@ class ModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} ) + use_unsloth: Optional[bool] = field( + default=False, + metadata={"help": "Whether to use unsloth's optimization for LoRA training."} + ) hf_hub_token: Optional[str] = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."} @@ -62,6 +66,30 @@ class ModelArguments: default=None, metadata={"help": "Auth token to log in with ModelScope Hub."} ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."} + ) + export_size: Optional[int] = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."} + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."} + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} + ) + export_quantization_nsamples: Optional[int] = field( + default=128, + metadata={"help": "The number of samples used for quantization."} + ) + export_quantization_maxlen: Optional[int] = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."} + ) def __post_init__(self): self.compute_dtype = None @@ -74,6 +102,10 @@ class ModelArguments: self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." + + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") def to_dict(self) -> Dict[str, Any]: return asdict(self) diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index fe002a7f..e3e4da15 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -93,16 +93,31 @@ def init_adapter( else: target_modules = finetuning_args.lora_target - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - inference_mode=False, - r=finetuning_args.lora_rank, - lora_alpha=finetuning_args.lora_alpha, - lora_dropout=finetuning_args.lora_dropout, - target_modules=target_modules, - modules_to_save=finetuning_args.additional_target - ) - model = get_peft_model(model, lora_config) + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout + } + + if model_args.use_unsloth: + from unsloth import FastLlamaModel, FastMistralModel # type: ignore + unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} + if getattr(model.config, "model_type", None) == "llama": + model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + elif getattr(model.config, "model_type", None) == "mistral": + model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + else: + raise NotImplementedError + + else: + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + modules_to_save=finetuning_args.additional_target, + **peft_kwargs + ) + model = get_peft_model(model, lora_config) for param in filter(lambda p: p.requires_grad, model.parameters()): param.data = param.data.to(torch.float32) diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index fe821224..0366ef27 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -4,10 +4,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.utils.versions import require_version from trl import AutoModelForCausalLMWithValueHead -import llmtuner.model.patcher as patcher from llmtuner.extras.logging import get_logger -from llmtuner.extras.misc import count_parameters, try_download_model_from_ms +from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms from llmtuner.model.adapter import init_adapter +from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model from llmtuner.model.utils import ( load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass ) @@ -24,7 +24,7 @@ require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36. require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") -require_version("trl==0.7.4", "To fix: pip install trl==0.7.4") +require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6") def load_model_and_tokenizer( @@ -52,26 +52,48 @@ def load_model_and_tokenizer( model_args.model_name_or_path, use_fast=model_args.use_fast_tokenizer, split_special_tokens=model_args.split_special_tokens, - padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow + padding_side="right", **config_kwargs ) + patch_tokenizer(tokenizer) + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) + patch_config(config, tokenizer, model_args, config_kwargs, is_trainable) - patcher.patch_tokenizer(tokenizer) - patcher.patch_config(config, model_args) - patcher.configure_rope(config, model_args, is_trainable) - patcher.configure_flashattn(config_kwargs, model_args) - patcher.configure_longlora(config, model_args, is_trainable) - patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args) + model = None + if is_trainable and model_args.use_unsloth: + require_version("unsloth==2023.12", "Follow the instructions at: https://github.com/unslothai/unsloth") + from unsloth import FastLlamaModel, FastMistralModel # type: ignore + unsloth_kwargs = { + "model_name": model_args.model_name_or_path, + "max_seq_length": model_args.model_max_length, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": get_current_device(), + "rope_scaling": getattr(config, "rope_scaling", None) + } + if getattr(config, "model_type", None) == "llama": + model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs) + elif getattr(config, "model_type", None) == "mistral": + model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs) + else: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model_args.use_unsloth = False + + if model_args.adapter_name_or_path: + model_args.adapter_name_or_path = None + logger.warning("Unsloth does not support loading adapters.") + + if model is None: + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), + **config_kwargs + ) - model = AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, - config=config, - low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), - **config_kwargs - ) model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model - patcher.patch_model(model) + patch_model(model) register_autoclass(config, model, tokenizer) if not is_deepspeed_zero3_enabled(): resize_embedding_layer(model, tokenizer) @@ -81,7 +103,7 @@ def load_model_and_tokenizer( if add_valuehead: model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) - patcher.patch_valuehead_model(model) + patch_valuehead_model(model) if model_args.adapter_name_or_path is not None: vhead_path = model_args.adapter_name_or_path[-1] @@ -94,7 +116,7 @@ def load_model_and_tokenizer( logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) if not is_trainable: - model.requires_grad_(False) # fix all model params + model.requires_grad_(False) model.eval() else: model.train() diff --git a/src/llmtuner/model/parser.py b/src/llmtuner/model/parser.py index 760ce88e..063233dc 100644 --- a/src/llmtuner/model/parser.py +++ b/src/llmtuner/model/parser.py @@ -104,8 +104,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: # Setup logging if training_args.should_log: - log_level = training_args.get_process_log_level() - _set_transformers_logging(log_level) + _set_transformers_logging() # Check arguments data_args.init_for_training(training_args.seed) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 38887b70..80438a25 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -18,85 +18,14 @@ from llmtuner.extras.packages import is_flash_attn2_available if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedTokenizer from trl import AutoModelForCausalLMWithValueHead - from llmtuner.hparams import ModelArguments, FinetuningArguments + from llmtuner.hparams import ModelArguments logger = get_logger(__name__) SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama -def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"): - if model_args.flash_attn and is_flash_attn2_available(): - config_kwargs["use_flash_attention_2"] = True - logger.info("Using FlashAttention-2 for faster training and inference.") - - -def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool): - if is_trainable and model_args.shift_attn: - if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: - setattr(config, "group_size_ratio", 0.25) - logger.info("Using shift short attention with group_size_ratio=1/4.") - else: - logger.warning("Current model does not support shift short attention.") - - -def configure_quantization( - config: "PretrainedConfig", - config_kwargs: Dict[str, Any], - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - finetuning_args: "FinetuningArguments" -): - r""" - Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) - """ - if getattr(config, "quantization_config", None): # gptq or awq - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") - - config_kwargs["device_map"] = {"": get_current_device()} - quantization_config = getattr(config, "quantization_config", None) - logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) - - elif finetuning_args.export_quantization_bit is not None: # gptq - require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") - require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") - from accelerate.utils import get_max_memory - - if getattr(config, "model_type", None) == "chatglm": - raise ValueError("ChatGLM model is not supported.") - - config_kwargs["quantization_config"] = GPTQConfig( - bits=finetuning_args.export_quantization_bit, - tokenizer=tokenizer, - dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args) - ) - config_kwargs["device_map"] = "auto" - config_kwargs["max_memory"] = get_max_memory() - logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit)) - - elif model_args.quantization_bit is not None: # bnb - if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") - - if model_args.quantization_bit == 8: - require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") - config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) - - elif model_args.quantization_bit == 4: - require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - config_kwargs["quantization_config"] = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_compute_dtype=model_args.compute_dtype, - bnb_4bit_use_double_quant=model_args.double_quantization, - bnb_4bit_quant_type=model_args.quantization_type - ) - - config_kwargs["device_map"] = {"": get_current_device()} - logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) - - -def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool): +def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool): if model_args.rope_scaling is not None: if not hasattr(config, "rope_scaling"): logger.warning("Current model does not support RoPE scaling.") @@ -123,27 +52,94 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ )) -def get_quantization_dataset( +def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]): + if model_args.flash_attn and is_flash_attn2_available(): + config_kwargs["use_flash_attention_2"] = True + config_kwargs["torch_dtype"] = model_args.compute_dtype + logger.info("Using FlashAttention-2 for faster training and inference.") + + +def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool): + if is_trainable and model_args.shift_attn: + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") + + +def _configure_quantization( + config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", - finetuning_args: "FinetuningArguments" -) -> List[str]: + config_kwargs: Dict[str, Any] +): + r""" + Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) + """ + if getattr(config, "quantization_config", None): # gptq or awq + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + config_kwargs["device_map"] = {"": get_current_device()} + quantization_config = getattr(config, "quantization_config", None) + logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) + + elif model_args.export_quantization_bit is not None: # gptq + require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported.") + + config_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + tokenizer=tokenizer, + dataset=get_quantization_dataset(tokenizer, model_args) + ) + config_kwargs["device_map"] = "auto" + config_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # bnb + if is_deepspeed_zero3_enabled(): + raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + config_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type + ) + + config_kwargs["device_map"] = {"": get_current_device()} + logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) + + +def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: r""" Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 """ - if os.path.isfile(finetuning_args.export_quantization_dataset): - data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None) - data_files = finetuning_args.export_quantization_dataset + if os.path.isfile(model_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) + data_files = model_args.export_quantization_dataset else: - data_path = finetuning_args.export_quantization_dataset + data_path = model_args.export_quantization_dataset data_files = None dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) - maxlen = finetuning_args.export_quantization_maxlen + maxlen = model_args.export_quantization_maxlen samples = [] - for _ in range(finetuning_args.export_quantization_nsamples): + for _ in range(model_args.export_quantization_nsamples): while True: sample_idx = random.randint(0, len(dataset) - 1) sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") @@ -151,13 +147,24 @@ def get_quantization_dataset( break # TODO: fix large maxlen word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) - input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen] + input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) return samples -def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"): +def patch_tokenizer(tokenizer: "PreTrainedTokenizer"): + if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + + +def patch_config( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + config_kwargs: Dict[str, Any], + is_trainable: bool +): if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) setattr(config, "torch_dtype", model_args.compute_dtype) @@ -166,6 +173,11 @@ def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"): for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype) + _configure_rope(config, model_args, is_trainable) + _configure_flashattn(model_args, config_kwargs) + _configure_longlora(config, model_args, is_trainable) + _configure_quantization(config, tokenizer, model_args, config_kwargs) + def patch_model(model: "PreTrainedModel"): if "GenerationMixin" not in str(model.generate.__func__): @@ -177,15 +189,15 @@ def patch_model(model: "PreTrainedModel"): def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"): - def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: - return self.pretrained_model.get_input_embeddings() + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() - setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] setattr(model, "_keys_to_ignore_on_save", ignore_modules) - setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method - - -def patch_tokenizer(tokenizer: "PreTrainedTokenizer"): - if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): - tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + setattr(model, "tie_weights", MethodType(tie_weights, model)) + setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index 4de79d82..7fafd98a 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -16,7 +16,7 @@ class CustomDPOTrainer(DPOTrainer): def __init__( self, beta: float, - loss_type: Literal["sigmoid", "hinge"], + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"], ftx_gamma: float, model: Union["PreTrainedModel", torch.nn.Module], ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, @@ -28,16 +28,20 @@ class CustomDPOTrainer(DPOTrainer): if ref_model is not None: disable_dropout_in_model(ref_model) - self.is_encoder_decoder = model.config.is_encoder_decoder - self.ref_model = ref_model self.use_dpo_data_collator = True # hack to avoid warning self.generate_during_eval = False # disable at evaluation self.label_pad_token_id = IGNORE_INDEX self.padding_value = 0 + self.is_encoder_decoder = model.config.is_encoder_decoder + self.precompute_ref_log_probs = False + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + self.ref_model = ref_model self.beta = beta self.label_smoothing = 0 - self.ftx_gamma = ftx_gamma self.loss_type = loss_type + self.ftx_gamma = ftx_gamma self._stored_metrics = defaultdict(lambda: defaultdict(list)) Trainer.__init__(self, model=model, **kwargs) @@ -95,7 +99,7 @@ class CustomDPOTrainer(DPOTrainer): chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) return chosen_logps, rejected_logps, chosen_logits, rejected_logits - def get_batch_metrics( + def get_batch_loss_metrics( self, model: "PreTrainedModel", batch: Dict[str, torch.Tensor], diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index 66b53877..0af124a3 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -37,7 +37,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra def export_model(args: Optional[Dict[str, Any]] = None): model_args, _, finetuning_args, _ = get_infer_args(args) - if model_args.adapter_name_or_path is not None and finetuning_args.export_quantization_bit is not None: + if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: raise ValueError("Please merge adapters before quantizing the model.") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) @@ -47,12 +47,12 @@ def export_model(args: Optional[Dict[str, Any]] = None): model.config.use_cache = True model = model.to("cpu") - model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size)) + model.save_pretrained(model_args.export_dir, max_shard_size="{}GB".format(model_args.export_size)) try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" - tokenizer.save_pretrained(finetuning_args.export_dir) + tokenizer.save_pretrained(model_args.export_dir) except: logger.warning("Cannot save tokenizer, please copy the files manually.")