diff --git a/examples/extras/badam/sft.sh b/examples/extras/badam/sft.sh
new file mode 100644
index 00000000..c2319caa
--- /dev/null
+++ b/examples/extras/badam/sft.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../../data \
+    --template default \
+    --finetuning_type full \
+    --use_badam \
+    --badam_switch_mode descending \
+    --badam_switch_block_every 50 \
+    --badam_verbose 2 \
+    --output_dir ../../../saves/LLaMA2-7B/badam/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --preprocessing_num_workers 16 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --warmup_steps 20 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --pure_bf16
diff --git a/setup.py b/setup.py
index fd5bdf7e..9ef881e2 100644
--- a/setup.py
+++ b/setup.py
@@ -24,6 +24,7 @@ extra_require = {
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
     "galore": ["galore-torch"],
+    "badam": ["badam"],
     "vllm": ["vllm>=0.3.3"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 177a9f8a..899c7284 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -172,7 +172,7 @@ class GaloreArguments:
 
     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use gradient low-Rank projection."},
+        metadata={"help": "Whether or not to use the gradient low-rank projection (GaLore)."},
     )
     galore_target: str = field(
         default="all",
@@ -204,7 +204,54 @@
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use the layer-wise or the ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the block to update. Set to -1 to disable block switching."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."},
+    )
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The update ratio for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": """The mode of the mask for the BAdam optimizer. \
+            `adjacent` means that the trainable parameters are adjacent to each other, \
+            `scatter` means that the trainable parameters are randomly chosen from the weight matrix."""
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": """The verbosity level of the BAdam optimizer. \
+            0 for no output, 1 for printing the block prefix, 2 for printing the trainable parameters."""
+        },
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 84712b3b..a9f8ffd7 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -88,6 +88,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
 
+    if finetuning_args.use_badam:
+        require_version("badam", "To fix: pip install badam")
+
     if training_args is not None and training_args.predict_with_generate:
         require_version("jieba", "To fix: pip install jieba")
         require_version("nltk", "To fix: pip install nltk")
@@ -172,7 +175,14 @@
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
     if finetuning_args.use_galore and training_args.deepspeed is not None:
-        raise ValueError("GaLore is incompatible with DeepSpeed.")
+        raise ValueError("GaLore does not support DeepSpeed yet.")
+
+    if (
+        finetuning_args.use_badam
+        and finetuning_args.badam_mode == "layer"
+        and training_args.parallel_mode.value == "distributed"
+    ):
+        raise ValueError("Layer-wise BAdam does not yet support distributed training; use ratio-wise BAdam instead.")
 
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index cabb21df..b712bdcf 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -37,7 +37,7 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -82,7 +82,7 @@
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
 
@@ -166,7 +166,7 @@
             )
             model = get_peft_model(model, lora_config)
 
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             for param in filter(lambda p: p.requires_grad, model.parameters()):
                 param.data = param.data.to(torch.float32)
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 49c9f598..fb2835e8 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -17,7 +17,7 @@
 from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
 
 
 if TYPE_CHECKING:
@@ -268,8 +268,8 @@ def _prepare_model_for_training(
     else:
         # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
         # According to: https://github.com/huggingface/transformers/issues/28339
+        model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-        model.enable_input_require_grads()
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 771e6112..7e4430d1 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -1,5 +1,6 @@
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List
+from functools import partial
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from transformers import PreTrainedModel
@@ -100,6 +101,37 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
     return module_names
 
 
+def gradient_checkpointing_enable(
+    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+) -> None:
+    r"""
+    Activates gradient checkpointing for the current model.
+
+    A modification of the original method that enables gradient checkpointing for block-wise optimizers.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def custom_gradient_checkpointing_func(func, *args, **kwargs):
+        module: "torch.nn.Module" = func.__self__
+
+        if any(param.requires_grad for param in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return gradient_checkpointing_func(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
+
+
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
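
Note on the hunk above: the custom `gradient_checkpointing_enable` replaces the `model.enable_input_require_grads()` call that `patcher.py` no longer makes. Instead of unconditionally forcing the embedding output to require gradients, it marks the floating-point inputs of a checkpointed module as requiring gradients only when that module itself holds trainable parameters, which is exactly the situation BAdam creates by freezing every block except the active one. The following self-contained sketch (illustrative only, not part of the patch) shows the PyTorch behaviour it works around: with use_reentrant=True, a checkpointed block whose inputs do not require gradients receives none.

import torch
from torch.utils.checkpoint import checkpoint

frozen = torch.nn.Linear(8, 8).requires_grad_(False)  # stands in for a frozen block
trainable = torch.nn.Linear(8, 8)                     # the block currently being updated

hidden = frozen(torch.randn(2, 8))                    # output does not require grad

out = checkpoint(trainable, hidden, use_reentrant=True)
print(out.requires_grad)                              # False: backward would skip `trainable`

hidden.requires_grad_(True)                           # what custom_gradient_checkpointing_func does
out = checkpoint(trainable, hidden, use_reentrant=True)
out.sum().backward()
print(trainable.weight.grad is not None)              # True: gradients reach the trainable block
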
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 8d2f9fa0..def427fd 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -1,5 +1,6 @@
 import json
 import os
+from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -28,6 +29,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if finetuning_args.use_badam:
+            from badam import clip_grad_norm_for_sparse_tensor
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index d921aec4..2835eddf 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -162,6 +162,15 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     return decay_parameters
 
 
+def _get_embedding_names(model: "PreTrainedModel") -> List[str]:
+    r"""
+    Returns the names of the parameters in the embedding layers.
+    """
+    result = {name for name, _ in model.get_input_embeddings().named_parameters()}
+    result.update(name for name, _ in model.get_output_embeddings().named_parameters())
+    return result
+
+
 def _create_galore_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
@@ -236,7 +245,7 @@
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
             dict(params=decay_params, weight_decay=training_args.weight_decay),
             dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
         ]
@@ -280,7 +289,7 @@
     param_groups = [
         dict(params=param_dict["lora_a"], **decay_args),
         dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
         dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
@@ -288,6 +297,68 @@
     return optimizer
 
 
+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    decay_param_names = _get_decay_parameter_names(model)
+    if finetuning_args.badam_mode == "ratio":  # filter out the embedding layers for ratio-wise BAdam
+        decay_param_names = [name for name in decay_param_names if name not in _get_embedding_names(model)]
+
+    decay_params, nodecay_params = [], []
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
+    ]
+
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
+
+    elif finetuning_args.badam_mode == "ratio":
+        from badam import BlockOptimizerRatio
+
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+
+    return optimizer
+
+
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
@@ -299,6 +370,9 @@
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
 
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
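
Taken together, these changes route optimizer creation through `_create_badam_optimizer` whenever `--use_badam` is set: in layer mode the regular optimizer is wrapped in `badam.BlockOptimizer`, in ratio mode it is replaced by `badam.BlockOptimizerRatio`, and the SFT trainer swaps in `clip_grad_norm_for_sparse_tensor` for gradient clipping. The sketch below is a hypothetical, self-contained illustration of the layer-wise wiring outside the trainer, using only the constructor arguments that appear in the diff; the model name and hyperparameters mirror examples/extras/badam/sft.sh, and it assumes BAdam can derive the block structure on its own when `block_prefix_list` is None, as the patch does.

import torch
from badam import BlockOptimizer
from transformers import AutoModelForCausalLM

# Placeholder model; the new example script uses the same checkpoint.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

base_optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
optimizer = BlockOptimizer(
    base_optimizer=base_optimizer,                       # the optimizer stepped for the active block
    named_parameters_list=list(model.named_parameters()),
    block_prefix_list=None,                              # let BAdam pick the block prefixes
    switch_block_every=50,                               # mirrors --badam_switch_block_every 50
    start_block=None,
    switch_mode="descending",                            # mirrors --badam_switch_mode descending
    verbose=2,                                           # mirrors --badam_verbose 2
)

# From here on it is used like any other torch optimizer:
# loss.backward(); optimizer.step(); optimizer.zero_grad()
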