From 06c8908d3fe48907ddb585c5fa15677fc5416f94 Mon Sep 17 00:00:00 2001
From: Jonery
Date: Mon, 15 Apr 2024 23:15:27 +0800
Subject: [PATCH] Feature BAdam

---
 examples/extras/badam/sft.sh            | 36 ++++++++++++++++
 requirements.txt                        |  1 +
 src/llmtuner/hparams/finetuning_args.py | 43 ++++++++++++++++++-
 src/llmtuner/hparams/parser.py          |  6 +++
 src/llmtuner/model/adapter.py           |  6 +--
 src/llmtuner/model/patcher.py           |  5 ++-
 src/llmtuner/model/utils.py             | 42 ++++++++++++++++++
 src/llmtuner/train/sft/trainer.py       |  6 ++-
 src/llmtuner/train/utils.py             | 57 +++++++++++++++++++++++++
 9 files changed, 195 insertions(+), 7 deletions(-)
 create mode 100644 examples/extras/badam/sft.sh

diff --git a/examples/extras/badam/sft.sh b/examples/extras/badam/sft.sh
new file mode 100644
index 00000000..daa63913
--- /dev/null
+++ b/examples/extras/badam/sft.sh
@@ -0,0 +1,36 @@
+# BAdam layer-wise
+export CUDA_VISIBLE_DEVICES=0
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+python ../../../src/train_bash.py \
+--stage sft \
+--do_train \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--dataset alpaca_gpt4_en,glaive_toolcall \
+--dataset_dir ../../../data \
+--template default \
+--finetuning_type full \
+--output_dir ../../../saves/LLaMA2-7B/badam \
+--overwrite_cache \
+--overwrite_output_dir \
+--cutoff_len 1024 \
+--preprocessing_num_workers 32 \
+--per_device_train_batch_size 8 \
+--per_device_eval_batch_size 5 \
+--gradient_accumulation_steps 2 \
+--lr_scheduler_type cosine \
+--logging_steps 10 \
+--warmup_steps 20 \
+--save_steps 100 \
+--eval_steps 100 \
+--evaluation_strategy steps \
+--load_best_model_at_end \
+--learning_rate 5e-5 \
+--num_train_epochs 3.0 \
+--val_size 0.1 \
+--plot_loss \
+--use_badam \
+--switch_mode descending \
+--badam_verbose 2 \
+--switch_block_every 50 \
+--pure_bf16 \
+
diff --git a/requirements.txt b/requirements.txt
index 1fa5a142..9d58d75a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ fastapi
 sse-starlette
 matplotlib
 fire
+badam
\ No newline at end of file
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 177a9f8a..d64f1583 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -163,6 +163,47 @@ class RLHFArguments:
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
 
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "The mode of the BAdam optimizer. 'layer' for layer-wise update, 'ratio' for ratio-wise update."},
+    )
+
+    # ======== Arguments for layer-wise update ========
+    start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for block-wise fine-tuning."}
+    )
+    switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often (in steps) to switch the block being updated. Set to -1 to disable block switching."}
+    )
+    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy for picking the block to update."}
+    )
+
+    # ======== Arguments for ratio-wise update ========
+    badam_update_ratio: float = field(
+        default=0.,
+        metadata={"help": "The ratio of parameters to update for the BAdam optimizer (ratio mode)."}
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={"help": "The mode of the mask for the BAdam optimizer. `adjacent` means that the trainable parameters are adjacent to each other; `scatter` means that the trainable parameters are randomly chosen from the weights."}
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={"help": "The verbosity level of the BAdam optimizer. 0 for no output, 1 for printing the block prefix, 2 for printing trainable parameters."}
+    )
 
 @dataclass
 class GaloreArguments:
@@ -204,7 +245,7 @@
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 8f3bd18a..032a1a4b 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -171,6 +171,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.use_galore and training_args.deepspeed is not None:
         raise ValueError("GaLore is incompatible with DeepSpeed.")
 
+    if (finetuning_args.use_badam
+        and finetuning_args.badam_mode == "layer"
+        and training_args.parallel_mode.value == "distributed"
+    ):
+        raise ValueError("BAdam with layer-wise mode is not yet supported in distributed training; use ratio mode instead.")
+
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
 
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 4bb4057d..71c9c2f4 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -37,7 +37,7 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -82,7 +82,7 @@ def init_adapter(
 
         for name, param in model.named_parameters():
             if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                     param.data = param.data.to(torch.float32)
             else:
                 param.requires_grad_(False)
@@ -162,7 +162,7 @@ def init_adapter(
         )
         model = get_peft_model(model, lora_config)
 
-    if not finetuning_args.pure_bf16:
+    if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
         for param in filter(lambda p: p.requires_grad, model.parameters()):
             param.data = param.data.to(torch.float32)
 
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index c48df995..563b1827 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
 
 
 if TYPE_CHECKING:
@@ -266,8 +266,9 @@ def _prepare_model_for_training(
     else:
         # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
         # According to: https://github.com/huggingface/transformers/issues/28339
+        model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
         model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-        model.enable_input_require_grads()
+        # model.enable_input_require_grads()
         setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
         logger.info("Gradient checkpointing enabled.")
 
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 771e6112..e83a903e 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -135,3 +135,45 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
         model.__class__.register_for_auto_class()
     if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
         tokenizer.__class__.register_for_auto_class()
+
+def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+    """
+    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
+
+    Activates gradient checkpointing for the current model.
+
+    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+    Args:
+        gradient_checkpointing_kwargs (dict, *optional*):
+            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
+    # gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def gradient_checkpointing_func(func, *args, **kwargs):
+        module = func.__self__
+
+        if any([p.requires_grad for p in module.parameters()]):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return checkpoint(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+
+    if getattr(self, "_hf_peft_config_loaded", False):
+        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+        # the gradients to make sure the gradient flows.
+        self.enable_input_require_grads()
\ No newline at end of file
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 8d2f9fa0..d750f491 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -9,7 +9,8 @@ from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ..utils import create_custom_optimzer, create_custom_scheduler
-
+from types import MethodType
+from packaging import version
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -28,6 +29,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if version.parse(torch.__version__) >= version.parse("1.13"):
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index d921aec4..65233f72 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -287,12 +287,69 @@ def _create_loraplus_optimizer(
     logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
 
+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+
+    from transformers.trainer_pt_utils import get_parameter_names
+    decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
+    # filter out the embedding layers when using badam ratio mode
+    if finetuning_args.badam_mode == "ratio":
+        decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters))  # TODO: make it more general
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+    # create the block-wise optimizer
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+        base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+        optimizer = BlockOptimizer(base_optimizer=base_optimizer,
+                                   named_parameters_list=list(model.named_parameters()),
+                                   block_prefix_list=None,
+                                   switch_block_every=finetuning_args.switch_block_every,
+                                   start_block=finetuning_args.start_block,
+                                   switch_mode=finetuning_args.switch_mode,
+                                   verbose=finetuning_args.badam_verbose)
+
+        logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
+                    f"switch block every {finetuning_args.switch_block_every} steps, "
+                    f"default start block is {finetuning_args.start_block}")
+
+    elif finetuning_args.badam_mode == "ratio":
+        assert finetuning_args.badam_update_ratio > 0.0, "badam_update_ratio must be positive in ratio mode."
+        from badam import BlockOptimizerRatio
+        optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
+                                        named_parameters_list=list(model.named_parameters()),
+                                        update_ratio=finetuning_args.badam_update_ratio,
+                                        mask_mode=finetuning_args.badam_mask_mode,
+                                        verbose=finetuning_args.badam_verbose,
+                                        **optimizer_kwargs)
+
+        logger.info(f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+                    f"mask mode is {finetuning_args.badam_mask_mode}")
+
+    return optimizer
 
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)
 
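
Usage note: the bundled examples/extras/badam/sft.sh only exercises the layer-wise mode. A ratio-wise run could look roughly like the sketch below, reusing the flags introduced by BAdamArgument above; the model, dataset, paths, and hyperparameter values are illustrative placeholders rather than settings shipped or tested by this patch.

# BAdam ratio-wise (illustrative sketch, not part of the diff above)
export CUDA_VISIBLE_DEVICES=0
python src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en \
--template default \
--finetuning_type full \
--output_dir saves/LLaMA2-7B/badam_ratio \
--per_device_train_batch_size 8 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--use_badam \
--badam_mode ratio \
--badam_update_ratio 0.05 \
--badam_mask_mode adjacent \
--badam_verbose 1 \
--pure_bf16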