From 8c1046d78ac6c8f9429b73617e35e1eccb35138f Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sun, 16 Jun 2024 01:08:12 +0800 Subject: [PATCH] support pissa --- README.md | 6 +- README_zh.md | 6 +- examples/README.md | 6 ++ examples/README_zh.md | 6 ++ examples/extras/pissa/llama3_lora_sft.yaml | 42 ++++++++++ scripts/loftq_init.py | 72 +++++++---------- scripts/pissa_init.py | 79 ++++++++++++++++++ src/llamafactory/hparams/finetuning_args.py | 20 ++++- src/llamafactory/hparams/model_args.py | 8 +- src/llamafactory/hparams/parser.py | 5 +- src/llamafactory/model/adapter.py | 25 ++++-- src/llamafactory/train/dpo/trainer.py | 13 ++- src/llamafactory/train/pt/trainer.py | 12 ++- src/llamafactory/train/sft/trainer.py | 13 ++- src/llamafactory/train/trainer_utils.py | 54 ++++++++++++- src/llamafactory/webui/components/train.py | 9 ++- src/llamafactory/webui/locales.py | 14 ++++ src/llamafactory/webui/runner.py | 2 + tests/model/test_pissa.py | 90 +++++++++++++++++++++ 19 files changed, 406 insertions(+), 76 deletions(-) create mode 100644 examples/extras/pissa/llama3_lora_sft.yaml create mode 100644 scripts/pissa_init.py create mode 100644 tests/model/test_pissa.py diff --git a/README.md b/README.md index cae79694..cb9a7222 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ Choose your path: - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. -- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning. +- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog -[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models. +[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. -[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models. +[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. diff --git a/README_zh.md b/README_zh.md index af3ff8f0..5c005f30 100644 --- a/README_zh.md +++ b/README_zh.md @@ -49,7 +49,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 -- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。 +- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 -[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。 +[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 -[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。 +[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 diff --git a/examples/README.md b/examples/README.md index a6d78936..902d26b1 100644 --- a/examples/README.md +++ b/examples/README.md @@ -213,3 +213,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ```bash bash examples/extras/fsdp_qlora/single_node.sh ``` + +#### PiSSA Fine-Tuning + +```bash +llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml +``` diff --git a/examples/README_zh.md b/examples/README_zh.md index b6168a95..586e498c 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -213,3 +213,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml ```bash bash examples/extras/fsdp_qlora/single_node.sh ``` + +#### PiSSA 微调 + +```bash +llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml +``` diff --git a/examples/extras/pissa/llama3_lora_sft.yaml b/examples/extras/pissa/llama3_lora_sft.yaml new file mode 100644 index 00000000..fd4b9f1d --- /dev/null +++ b/examples/extras/pissa/llama3_lora_sft.yaml @@ -0,0 +1,42 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + +### method +stage: sft +do_train: true +finetuning_type: lora +lora_target: all +pissa_init: true +pissa_iter: 4 +pissa_convert: true + +### dataset +dataset: identity,alpaca_en_demo +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### output +output_dir: saves/llama3-8b/lora/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### train +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 1.0e-4 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +fp16: true +ddp_timeout: 180000000 + +### eval +val_size: 0.1 +per_device_eval_batch_size: 1 +eval_strategy: steps +eval_steps: 500 diff --git a/scripts/loftq_init.py b/scripts/loftq_init.py index 159dea06..556f342c 100644 --- a/scripts/loftq_init.py +++ b/scripts/loftq_init.py @@ -1,7 +1,7 @@ # coding=utf-8 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by HuggingFace's PEFT library. +# This code is based on the HuggingFace's PEFT library. # https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,11 +17,9 @@ # limitations under the License. import os -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import fire -import torch -import torch.nn as nn from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer @@ -30,41 +28,20 @@ if TYPE_CHECKING: from transformers import PreTrainedModel -class Shell(nn.Module): - def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - super().__init__() - self.weight = nn.Parameter(weight, requires_grad=False) - if bias is not None: - self.bias = nn.Parameter(bias, requires_grad=False) - - -def unwrap_model(model: nn.Module, pattern=".base_layer") -> None: - for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}: - parent_name = ".".join(name.split(".")[:-1]) - child_name = name.split(".")[-1] - parent_module = model.get_submodule(parent_name) - child_module = getattr(parent_module, child_name) - base_layer = getattr(child_module, "base_layer") - weight = getattr(base_layer, "weight", None) - bias = getattr(base_layer, "bias", None) - setattr(parent_module, child_name, Shell(weight, bias)) - - print("Model unwrapped.") - - def quantize_loftq( model_name_or_path: str, - save_dir: str, - loftq_bits: Optional[int] = 4, - loftq_iter: Optional[int] = 1, - lora_alpha: Optional[int] = None, - lora_rank: Optional[int] = 16, - lora_target: Optional[str] = "q_proj,v_proj", - save_safetensors: Optional[bool] = False, + output_dir: str, + loftq_bits: int = 4, + loftq_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: str = "q_proj,v_proj", + save_safetensors: bool = True, ): r""" Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ) - Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir + Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir """ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") @@ -74,25 +51,34 @@ def quantize_loftq( inference_mode=True, r=lora_rank, lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, - lora_dropout=0.1, + lora_dropout=lora_dropout, target_modules=[name.strip() for name in lora_target.split(",")], init_lora_weights="loftq", loftq_config=loftq_config, ) # Init LoftQ model - lora_model = get_peft_model(model, lora_config) - base_model: "PreTrainedModel" = lora_model.get_base_model() + print("Initializing LoftQ weights, it may be take several minutes, wait patiently.") + peft_model = get_peft_model(model, lora_config) + loftq_dir = os.path.join(output_dir, "loftq_init") # Save LoftQ model - setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir) - setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True) - lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors) + setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir) + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply loftq again + peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(loftq_dir)) # Save base model - unwrap_model(base_model) - base_model.save_pretrained(save_dir, safe_serialization=save_safetensors) - tokenizer.save_pretrained(save_dir) + base_model: "PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(loftq_dir)) + print("finetuning_type: lora") + print("quantization_bit: {}".format(loftq_bits)) if __name__ == "__main__": diff --git a/scripts/pissa_init.py b/scripts/pissa_init.py new file mode 100644 index 00000000..1b673c45 --- /dev/null +++ b/scripts/pissa_init.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is based on the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import TYPE_CHECKING + +import fire +from peft import LoraConfig, TaskType, get_peft_model +from transformers import AutoModelForCausalLM, AutoTokenizer + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + +def quantize_pissa( + model_name_or_path: str, + output_dir: str, + pissa_iter: int = 4, + lora_alpha: int = None, + lora_rank: int = 16, + lora_dropout: float = 0, + lora_target: str = "q_proj,v_proj", + save_safetensors: bool = True, +): + r""" + Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA) + Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir + """ + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto") + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=lora_rank, + lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2, + lora_dropout=lora_dropout, + target_modules=[name.strip() for name in lora_target.split(",")], + init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter) + ) + + # Init PiSSA model + peft_model = get_peft_model(model, lora_config) + pissa_dir = os.path.join(output_dir, "pissa_init") + + # Save PiSSA model + setattr(peft_model.peft_config["default"], "init_lora_weights", True) # don't apply pissa again + peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors) + print("Adapter weights saved in {}".format(pissa_dir)) + + # Save base model + base_model: "PreTrainedModel" = peft_model.unload() + base_model.save_pretrained(output_dir, safe_serialization=save_safetensors) + tokenizer.save_pretrained(output_dir) + print("Model weights saved in {}".format(output_dir)) + + print("Fine-tune this model with:") + print("model_name_or_path: {}".format(output_dir)) + print("adapter_name_or_path: {}".format(pissa_dir)) + print("finetuning_type: lora") + print("pissa_convert: true") + + +if __name__ == "__main__": + fire.Fire(quantize_pissa) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 52dc299e..1ef46eca 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -108,6 +108,18 @@ class LoraArguments: default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=4, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) create_new_adapter: bool = field( default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, @@ -340,7 +352,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA self.additional_target: Optional[List[str]] = split_arg(self.additional_target) self.galore_target: List[str] = split_arg(self.galore_target) self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only - self.use_ref_model = self.pref_loss not in ["orpo", "simpo"] + self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]) assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." @@ -367,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + if self.pissa_convert and self.finetuning_type != "lora": + raise ValueError("`pissa_convert` is only valid for LoRA training.") + + if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model): + raise ValueError("Cannot use PiSSA for current training stage.") + if self.train_mm_proj_only and self.finetuning_type != "full": raise ValueError("`train_mm_proj_only` is only valid for full training.") diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 53bdbdf2..996e9130 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -1,6 +1,6 @@ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by HuggingFace's transformers library. +# This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -45,6 +45,10 @@ class ModelArguments: ) }, ) + adapter_folder: Optional[str] = field( + default=None, + metadata={"help": "The folder containing the adapter weights to load."}, + ) cache_dir: Optional[str] = field( default=None, metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, @@ -150,7 +154,7 @@ class ModelArguments: metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, ) vllm_max_lora_rank: int = field( - default=8, + default=32, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, ) offload_folder: str = field( diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 1c57567c..31a805f6 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -1,6 +1,6 @@ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by HuggingFace's transformers library. +# This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -90,6 +90,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin if finetuning_args.finetuning_type != "lora": raise ValueError("Quantization is only compatible with the LoRA method.") + if finetuning_args.use_pissa: + raise ValueError("Please use scripts/pissa_init.py for quantized PiSSA.") + if model_args.resize_vocab: raise ValueError("Cannot resize embedding layers of a quantized model.") diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index dfa71525..a8f3a256 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -179,8 +179,16 @@ def _setup_lora_tuning( else: adapter_to_merge = model_args.adapter_name_or_path + init_kwargs = { + "subfolder": model_args.adapter_folder, + "offload_folder": model_args.offload_folder, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + for adapter in adapter_to_merge: - model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder) + model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) model = model.merge_and_unload() if len(adapter_to_merge) > 0: @@ -190,12 +198,7 @@ def _setup_lora_tuning( if model_args.use_unsloth: model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) else: - model = PeftModel.from_pretrained( - model, - adapter_to_resume, - is_trainable=is_trainable, - offload_folder=model_args.offload_folder, - ) + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) @@ -242,6 +245,14 @@ def _setup_lora_tuning( if model_args.use_unsloth: model = get_unsloth_peft_model(model, model_args, peft_kwargs) else: + if finetuning_args.pissa_init: + if finetuning_args.pissa_iter == -1: + logger.info("Using PiSSA initialization.") + peft_kwargs["init_lora_weights"] = "pissa" + else: + logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) + peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) + lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 475d08c3..9928d0bc 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -1,6 +1,6 @@ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by HuggingFace's TRL library. +# This code is inspired by the HuggingFace's TRL library. # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import warnings from collections import defaultdict from contextlib import nullcontext @@ -28,7 +29,7 @@ from trl import DPOTrainer from trl.trainer import disable_dropout_in_model from ...extras.constants import IGNORE_INDEX -from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler, get_batch_logps if TYPE_CHECKING: @@ -91,6 +92,9 @@ class CustomDPOTrainer(DPOTrainer): self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) self.ref_model.eval() + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) + if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -109,8 +113,11 @@ class CustomDPOTrainer(DPOTrainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor": diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 09729f2e..f9e04cb5 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from types import MethodType from typing import TYPE_CHECKING, Dict, Optional from transformers import Trainer from ...extras.logging import get_logger -from ..trainer_utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: @@ -42,6 +43,10 @@ class CustomTrainer(Trainer): super().__init__(**kwargs) self.finetuning_args = finetuning_args self.processor = processor + + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) + if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -60,6 +65,9 @@ class CustomTrainer(Trainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 6ab6914e..921e49ab 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -1,6 +1,6 @@ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by HuggingFace's transformers library. +# This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,7 +26,7 @@ from transformers import Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger -from ..trainer_utils import create_custom_optimzer, create_custom_scheduler +from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: @@ -51,6 +51,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): super().__init__(**kwargs) self.finetuning_args = finetuning_args self.processor = processor + + if finetuning_args.pissa_convert: + self.save_model(os.path.join(self.args.output_dir, "pissa_init")) + if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor @@ -69,8 +73,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) + output_dir = output_dir if output_dir is not None else self.args.output_dir + if self.finetuning_args.pissa_convert: + convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args) + if self.processor is not None: - output_dir = output_dir if output_dir is not None else self.args.output_dir getattr(self.processor, "image_processor").save_pretrained(output_dir) def prediction_step( diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 5621d5df..2d6bab24 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -1,9 +1,9 @@ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team. # -# This code is inspired by the GaLore's implementation: https://github.com/jiaweizzhao/GaLore -# and the LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus -# and the BAdam's implementation: https://github.com/Ledzy/BAdam -# and the TRL's implementation: https://github.com/huggingface/trl +# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore +# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus +# and the original BAdam's implementation: https://github.com/Ledzy/BAdam +# and the HuggingFace's TRL library: https://github.com/huggingface/trl # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import torch +from peft import PeftModel from transformers import Trainer from transformers.optimization import get_scheduler from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS @@ -37,6 +39,7 @@ if is_galore_available(): if TYPE_CHECKING: + from accelerate import Accelerator from transformers import PreTrainedModel, Seq2SeqTrainingArguments from trl import AutoModelForCausalLMWithValueHead @@ -171,6 +174,49 @@ def create_reward_model( return reward_model +def convert_pissa_adapter( + output_dir: str, + state_dict: Dict[str, "torch.Tensor"], + accelerator: "Accelerator", + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", +) -> None: + r""" + Converts the PiSSA adapter to a LoRA adapter. + """ + pissa_init_dir = os.path.join(training_args.output_dir, "pissa_init") + pissa_backup_dir = os.path.join(output_dir, "pissa_backup") + if output_dir == pissa_init_dir: + logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir)) + unwrapped_model = accelerator.unwrap_model(model) + if isinstance(unwrapped_model, PeftModel): + init_lora_weights = getattr(unwrapped_model.peft_config["default"], "init_lora_weights") + setattr(unwrapped_model.peft_config["default"], "init_lora_weights", True) + unwrapped_model.save_pretrained( + output_dir, + state_dict=state_dict, + safe_serialization=training_args.save_safetensors, + ) + setattr(unwrapped_model.peft_config["default"], "init_lora_weights", init_lora_weights) + elif output_dir == training_args.output_dir: # at the end of training + logger.info("Converted PiSSA adapter will be saved at: {}.".format(output_dir)) + unwrapped_model = accelerator.unwrap_model(model) + if isinstance(unwrapped_model, PeftModel): # backup the pissa adapter for further use + unwrapped_model.save_pretrained( + pissa_backup_dir, + state_dict=state_dict, + safe_serialization=training_args.save_safetensors, + ) + unwrapped_model.save_pretrained( + output_dir, + state_dict=state_dict, + safe_serialization=training_args.save_safetensors, + convert_pissa_to_lora=pissa_init_dir, + ) + unwrapped_model.load_adapter(pissa_backup_dir, "default", is_trainable=True) + unwrapped_model.set_adapter("default") + + def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: r""" Returns a list of names of parameters with weight decay. (weights in non-layernorm layers) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 673f6bf4..874f3c5e 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -163,10 +163,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter = gr.Checkbox() with gr.Row(): - with gr.Column(scale=1): - use_rslora = gr.Checkbox() - use_dora = gr.Checkbox() - + use_rslora = gr.Checkbox() + use_dora = gr.Checkbox() + use_pissa = gr.Checkbox() lora_target = gr.Textbox(scale=2) additional_target = gr.Textbox(scale=2) @@ -179,6 +178,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter, use_rslora, use_dora, + use_pissa, lora_target, additional_target, } @@ -193,6 +193,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: create_new_adapter=create_new_adapter, use_rslora=use_rslora, use_dora=use_dora, + use_pissa=use_pissa, lora_target=lora_target, additional_target=additional_target, ) diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index 427f01b8..8e8d6fce 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -732,6 +732,20 @@ LOCALES = { "info": "使用权重分解的 LoRA。", }, }, + "use_pissa": { + "en": { + "label": "Use PiSSA", + "info": "Use PiSSA method.", + }, + "ru": { + "label": "используйте PiSSA", + "info": "Используйте метод PiSSA.", + }, + "zh": { + "label": "使用 PiSSA", + "info": "使用 PiSSA 方法。", + }, + }, "lora_target": { "en": { "label": "LoRA modules (optional)", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 76982934..13dbba03 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -173,6 +173,8 @@ class Runner: args["create_new_adapter"] = get("train.create_new_adapter") args["use_rslora"] = get("train.use_rslora") args["use_dora"] = get("train.use_dora") + args["pissa_init"] = get("train.use_pissa") + args["pissa_convert"] = get("train.use_pissa") args["lora_target"] = get("train.lora_target") or "all" args["additional_target"] = get("train.additional_target") or None diff --git a/tests/model/test_pissa.py b/tests/model/test_pissa.py new file mode 100644 index 00000000..70c424fd --- /dev/null +++ b/tests/model/test_pissa.py @@ -0,0 +1,90 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from peft import LoraModel, PeftModel +from transformers import AutoModelForCausalLM + +from llamafactory.extras.misc import get_current_device +from llamafactory.hparams import get_infer_args, get_train_args +from llamafactory.model import load_model, load_tokenizer + + +TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") + +TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa") + +TRAIN_ARGS = { + "model_name_or_path": TINY_LLAMA, + "stage": "sft", + "do_train": True, + "finetuning_type": "lora", + "pissa_init": True, + "pissa_iter": -1, + "dataset": "llamafactory/tiny-supervised-dataset", + "dataset_dir": "ONLINE", + "template": "llama3", + "cutoff_len": 1024, + "overwrite_cache": True, + "output_dir": "dummy_dir", + "overwrite_output_dir": True, + "fp16": True, +} + +INFER_ARGS = { + "model_name_or_path": TINY_LLAMA_PISSA, + "adapter_name_or_path": TINY_LLAMA_PISSA, + "adapter_folder": "pissa_init", + "finetuning_type": "lora", + "template": "llama3", + "infer_dtype": "float16", +} + + +def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"): + state_dict_a = model_a.state_dict() + state_dict_b = model_b.state_dict() + assert set(state_dict_a.keys()) == set(state_dict_b.keys()) + for name in state_dict_a.keys(): + assert torch.allclose(state_dict_a[name], state_dict_b[name]) + + +def test_pissa_init(): + model_args, _, _, finetuning_args, _ = get_train_args(TRAIN_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init", is_trainable=True) + for param in filter(lambda p: p.requires_grad, ref_model.parameters()): + param.data = param.data.to(torch.float32) + + compare_model(model, ref_model) + + +def test_pissa_inference(): + model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS) + tokenizer_module = load_tokenizer(model_args) + model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=False) + + base_model = AutoModelForCausalLM.from_pretrained( + TINY_LLAMA_PISSA, torch_dtype=torch.float16, device_map=get_current_device() + ) + ref_model: "LoraModel" = PeftModel.from_pretrained(base_model, TINY_LLAMA_PISSA, subfolder="pissa_init") + ref_model = ref_model.merge_and_unload() + compare_model(model, ref_model)