diff --git a/README.md b/README.md index 9f33b539..a0588a5a 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ Choose your path: - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO. - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8. -- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning. +- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning. - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm. + [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer. [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.) diff --git a/README_zh.md b/README_zh.md index 2dfb3771..24ba3e12 100644 --- a/README_zh.md +++ b/README_zh.md @@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd - **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。 -- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。 +- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd ## 更新日志 +[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。 + [24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。 [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。) diff --git a/examples/extras/lora_plus/sft.sh b/examples/extras/loraplus/sft.sh similarity index 91% rename from examples/extras/lora_plus/sft.sh rename to examples/extras/loraplus/sft.sh index fff1097e..8bc16cdf 100644 --- a/examples/extras/lora_plus/sft.sh +++ b/examples/extras/loraplus/sft.sh @@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ --template default \ --finetuning_type lora \ --lora_target q_proj,v_proj \ - --output_dir ../../saves/LLaMA2-7B/lora_plus/sft \ + --output_dir ../../saves/LLaMA2-7B/loraplus/sft \ --overwrite_cache \ --overwrite_output_dir \ --cutoff_len 1024 \ @@ -30,4 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \ --val_size 0.1 \ --plot_loss \ --fp16 \ - --lora_lr_ratio 16.0 + --loraplus_lr_ratio 16.0 diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 424e5751..8188fdcc 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -57,7 +57,7 @@ class LoraArguments: metadata={ "help": """Name(s) of target modules to apply LoRA. \ Use commas to separate multiple modules. \ - Use "all" to specify all the available modules. \ + Use "all" to specify all the linear modules. \ LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \ Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ @@ -66,6 +66,14 @@ class LoraArguments: Others choices: the same as LLaMA.""" }, ) + loraplus_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, + ) + loraplus_lr_embedding: float = field( + default=1e-6, + metadata={"help": "LoRA plus learning rate for lora embedding layers."}, + ) use_rslora: bool = field( default=False, metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, @@ -163,8 +171,11 @@ class GaloreArguments: metadata={"help": "Whether or not to use gradient low-Rank projection."}, ) galore_target: str = field( - default="mlp,attn", - metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."}, + default="all", + metadata={ + "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \ + Use "all" to specify all the linear modules.""" + }, ) galore_rank: int = field( default=16, @@ -210,11 +221,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, ) - # for lora+,[LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf) - lora_lr_ratio: Optional[float] = field( - default=None, - metadata={'help': 'The lora learning_rate ratio of lora_A to lora_B, option:16.0.'}, - ) plot_loss: bool = field( default=False, metadata={"help": "Whether or not to save the training loss curves."}, @@ -230,6 +236,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA self.lora_alpha = self.lora_alpha or self.lora_rank * 2 self.lora_target = split_arg(self.lora_target) self.additional_target = split_arg(self.additional_target) + self.galore_target = split_arg(self.galore_target) assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index bb7c4db9..4b1b26fc 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,5 +1,5 @@ from .loader import load_model, load_model_and_tokenizer, load_tokenizer -from .utils import load_valuehead_params +from .utils import find_all_linear_modules, load_valuehead_params __all__ = [ @@ -7,4 +7,5 @@ __all__ = [ "load_model_and_tokenizer", "load_tokenizer", "load_valuehead_params", + "find_all_linear_modules", ] diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index 81f0b7f6..eb6d3878 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger -from .utils import find_all_linear_modules, find_expanded_modules +from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules if TYPE_CHECKING: @@ -129,9 +129,9 @@ def init_adapter( if finetuning_args.use_llama_pro: target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) - if finetuning_args.use_dora: - if getattr(model, "quantization_method", None): - raise ValueError("DoRA is currently not compatible with quantized models.") + if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None: + if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES: + raise ValueError("DoRA is not compatible with PTQ-quantized models.") peft_kwargs = { "r": finetuning_args.lora_rank, diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 0f886c37..c7ffb675 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -109,10 +109,6 @@ def load_model( if not is_trainable: model.requires_grad_(False) - if not getattr(model, "quantization_method", None): - for param in filter(lambda p: p.device.type == "cuda", model.parameters()): - param.data = param.data.to(model_args.compute_dtype) - model.eval() else: model.train() diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index 0d8b9d79..a5788a7c 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype from ..extras.packages import is_flash_attn2_available from ..extras.patches.llama_patch import apply_llama_patch from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl +from .utils import QuantizationMethod if TYPE_CHECKING: @@ -173,10 +174,10 @@ def _configure_quantization( quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) quant_method = quantization_config.get("quant_method", "") - if quant_method == "gptq": + if quant_method == QuantizationMethod.GPTQ: quantization_config["use_exllama"] = False # disable exllama - if quant_method == "aqlm": + if quant_method == QuantizationMethod.AQLM: require_version( "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git" ) @@ -205,7 +206,7 @@ def _configure_quantization( elif model_args.quantization_bit is not None: # bnb if is_deepspeed_zero3_enabled(): - raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") if model_args.quantization_bit == 8: require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py index 4a4ecf2e..5a437491 100644 --- a/src/llmtuner/model/utils.py +++ b/src/llmtuner/model/utils.py @@ -1,3 +1,4 @@ +from enum import Enum, unique from typing import TYPE_CHECKING, Dict, List import torch @@ -17,6 +18,18 @@ if TYPE_CHECKING: logger = get_logger(__name__) +@unique +class QuantizationMethod(str, Enum): + r""" + Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. + """ + + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + + def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: r""" Finds all available modules to apply lora. @@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: quantization_method = getattr(model, "quantization_method", None) if quantization_method is None: linear_cls = torch.nn.Linear - elif quantization_method == "bitsandbytes": + elif quantization_method == QuantizationMethod.BITS_AND_BYTES: import bitsandbytes as bnb linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py index 2ad02a54..099edc14 100644 --- a/src/llmtuner/train/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer from ...train.sft.metric import ComputeMetrics from ...train.sft.trainer import CustomSeq2SeqTrainer from ...train.utils import create_modelcard_and_push -from ..utils import create_custom_optimzer, create_lora_plus_optimizer +from ..utils import create_custom_optimzer if TYPE_CHECKING: @@ -51,8 +51,6 @@ def run_sft( # Initialize our Trainer optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args) - if finetuning_args.lora_lr_ratio: - optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args) trainer = CustomSeq2SeqTrainer( model=model, args=training_args, diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index a1b7bec4..43b76bef 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra def export_model(args: Optional[Dict[str, Any]] = None): model_args, data_args, finetuning_args, _ = get_infer_args(args) + model_args.device_map = {"": "cpu"} + if model_args.export_dir is None: - raise ValueError("Please specify `export_dir`.") + raise ValueError("Please specify `export_dir` to save model.") if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: raise ValueError("Please merge adapters before quantizing the model.") @@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None): if not isinstance(model, PreTrainedModel): raise ValueError("The model is not a `PreTrainedModel`, export aborted.") - if getattr(model, "quantization_method", None): - model = model.to("cpu") - elif hasattr(model.config, "torch_dtype"): - model = model.to(getattr(model.config, "torch_dtype")).to("cpu") - else: - model = model.to(torch.float16).to("cpu") - setattr(model.config, "torch_dtype", torch.float16) + if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model + output_dtype = getattr(model.config, "torch_dtype", torch.float16) + model = model.to(output_dtype) + setattr(model.config, "torch_dtype", output_dtype) model.save_pretrained( save_directory=model_args.export_dir, diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index d6e11a9b..42294164 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -1,15 +1,17 @@ import math from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from transformers.trainer import Trainer + import torch -from torch import nn +from transformers import Trainer from transformers.optimization import get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.trainer_pt_utils import get_parameter_names from transformers.utils.versions import require_version from ..extras.logging import get_logger from ..extras.packages import is_galore_available from ..hparams import FinetuningArguments, ModelArguments -from ..model import load_model_and_tokenizer, load_valuehead_params +from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params if is_galore_available(): @@ -29,9 +31,10 @@ logger = get_logger(__name__) class DummyOptimizer(torch.optim.Optimizer): - def __init__(self, *args, **kwargs): + def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None: dummy_tensor = torch.randn(1, 1) - super().__init__([dummy_tensor], {"lr": 1e-3}) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": lr}) def zero_grad(self, set_to_none: bool = True) -> None: pass @@ -142,59 +145,33 @@ def create_reward_model( return reward_model -def create_custom_optimzer( +def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: + r""" + Returns a list of names of parameters with weight decay. (weights in non-layernorm layers) + """ + decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + return decay_parameters + + +def _create_galore_optimizer( model: "PreTrainedModel", dataset: Union["Dataset", "IterableDataset"], training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments", -) -> Optional["torch.optim.Optimizer"]: - if not finetuning_args.use_galore: - return None - +) -> "torch.optim.Optimizer": require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git") - galore_params: List[torch.nn.Parameter] = [] - galore_targets = finetuning_args.galore_target.split(",") + if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": + galore_targets = find_all_linear_modules(model) + + galore_params: List["torch.nn.Parameter"] = [] for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets): for param in module.parameters(): if param.requires_grad and len(param.shape) > 1: galore_params.append(param) - id_galore_params = {id(param) for param in galore_params} - trainable_params = filter(lambda param: param.requires_grad, model.parameters()) - non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params] - - if training_args.optim == "adamw_torch": - optim_class = GaLoreAdamW - optim_kwargs = { - "lr": training_args.learning_rate, - "eps": training_args.adam_epsilon, - "betas": (training_args.adam_beta1, training_args.adam_beta2), - "weight_decay": training_args.weight_decay, - } - - elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]: - optim_class = GaLoreAdamW8bit - optim_kwargs = { - "lr": training_args.learning_rate, - "eps": training_args.adam_epsilon, - "betas": (training_args.adam_beta1, training_args.adam_beta2), - "weight_decay": training_args.weight_decay, - "optim_bits": 8, - "is_paged": "paged" in training_args.optim, - } - - elif training_args.optim == "adafactor": - optim_class = GaLoreAdafactor - optim_kwargs = { - "lr": training_args.learning_rate, - "weight_decay": training_args.weight_decay, - } - - else: - raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) - galore_kwargs = { "rank": finetuning_args.galore_rank, "update_proj_gap": finetuning_args.galore_update_interval, @@ -202,6 +179,30 @@ def create_custom_optimzer( "proj_type": finetuning_args.galore_proj_type, } + id_galore_params = {id(param) for param in galore_params} + decay_params, nodecay_params = [], [] # they are non-galore parameters + trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + trainable_params.append(param) + if id(param) not in id_galore_params: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + + if training_args.optim == "adamw_torch": + optim_class = GaLoreAdamW + elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]: + optim_class = GaLoreAdamW8bit + elif training_args.optim == "adafactor": + optim_class = GaLoreAdafactor + else: + raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) + if finetuning_args.galore_layerwise: if training_args.gradient_accumulation_steps != 1: raise ValueError("Per-layer GaLore does not support gradient accumulation.") @@ -213,15 +214,18 @@ def create_custom_optimzer( num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} - for param in non_galore_params: + for param in nodecay_params: param_groups = [dict(params=[param])] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in decay_params: + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) for param in galore_params: - param_groups = [dict(params=[param], **galore_kwargs)] + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)] optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {} - for param in non_galore_params + galore_params: + for param in trainable_params: scheduler_dict[param] = get_scheduler( training_args.lr_scheduler_type, optimizer=optimizer_dict[param], @@ -235,99 +239,72 @@ def create_custom_optimzer( optimizer_dict[param].zero_grad() scheduler_dict[param].step() - for param in non_galore_params + galore_params: + for param in trainable_params: param.register_post_accumulate_grad_hook(optimizer_hook) - optimizer = DummyOptimizer() + optimizer = DummyOptimizer(lr=training_args.learning_rate) # display scheduler result else: - param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)] + param_groups = [ + dict(params=nodecay_params), + dict(params=decay_params, weight_decay=training_args.weight_decay), + dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs), + ] optimizer = optim_class(param_groups, **optim_kwargs) logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") return optimizer -def optimizer_group_callback(model, lora_lr_ratio, **defaults): - "lora plus" - params = [] - names = set() +def _create_loraplus_optimizer( + model: "PreTrainedModel", + dataset: Union["Dataset", "IterableDataset"], + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + if finetuning_args.finetuning_type != "lora": + raise ValueError("You should use LoRA tuning to activate LoRA+.") + + loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio + decay_args = {"weight_decay": training_args.weight_decay} + + decay_param_names = _get_decay_parameter_names(model) + param_dict: Dict[str, List["torch.nn.Parameter"]] = { + "lora_a": [], + "lora_b": [], + "lora_b_nodecay": [], + "embedding": [], + } for name, param in model.named_parameters(): - if "default" in name and ('lora_B' in name or - 'lora_embedding_B' in name): - params.append(param) - names.add(name) - if params: - assert 'lr' in defaults - return names, { - 'params': params, - 'lr': defaults['lr'] * lora_lr_ratio, - } - return None, None - - -def create_lora_plus_optimizer( - model: "PreTrainedModel", - training_args: "Seq2SeqTrainingArguments", - finetuning_args: "FinetuningArguments", -) -> Optional["torch.optim.Optimizer"]: - if finetuning_args.lora_lr_ratio is None: - return None - all_param_names = set() - param_groups = [] - param_names, param_group = optimizer_group_callback( - model, lora_lr_ratio=finetuning_args.lora_lr_ratio, - lr=training_args.learning_rate, - weight_decay=training_args.weight_decay) - if param_names and all_param_names & param_names: - raise ValueError( - 'Cannot set one parameter to different param groups') - if param_names and param_group: - all_param_names.update(param_names) - param_groups.append(param_group) - - opt_model = model - decay_parameters = Trainer.get_decay_parameter_names(None, opt_model) - param_groups.extend([ - { - 'params': [ - p for n, p in opt_model.named_parameters() - if (n in decay_parameters and n not in all_param_names and p.requires_grad) - ], - 'weight_decay': - training_args.weight_decay, - }, - { - 'params': [ - p for n, p in opt_model.named_parameters() - if (n not in decay_parameters and n not in all_param_names and p.requires_grad) - ], - 'weight_decay': - 0.0, - }, - ]) - - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) - - optimizer = optimizer_cls(param_groups, **optimizer_kwargs) - - if optimizer_cls.__name__ == 'Adam8bit': - import bitsandbytes - - manager = bitsandbytes.optim.GlobalOptimManager.get_instance() - - skipped = 0 - for module in opt_model.modules(): - if isinstance(module, nn.Embedding): - skipped += sum({ - p.data_ptr(): p.numel() - for p in module.parameters() - }.values()) - logger.info( - f'skipped {module}: {skipped / 2 ** 20}M params') - manager.register_module_override( - module, 'weight', {'optim_bits': 32}) - logger.debug( - f'bitsandbytes: will optimize {module} in fp32') - logger.info(f'skipped: {skipped / 2 ** 20}M params') + if param.requires_grad: + if "lora_embedding_B" in name: + param_dict["embedding"].append(param) + elif "lora_B" in name or param.ndim == 1: + if name in decay_param_names: + param_dict["lora_b"].append(param) + else: + param_dict["lora_b_nodecay"].append(param) + else: + param_dict["lora_a"].append(param) + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=param_dict["lora_a"], **decay_args), + dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args), + dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr), + dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args), + ] + optimizer = optim_class(param_groups, **optim_kwargs) return optimizer + + +def create_custom_optimzer( + model: "PreTrainedModel", + dataset: Union["Dataset", "IterableDataset"], + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> Optional["torch.optim.Optimizer"]: + if not finetuning_args.use_galore: + return _create_galore_optimizer(model, dataset, training_args, finetuning_args) + + if finetuning_args.loraplus_lr_ratio is not None: + return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)