From c9828f4c6e6c150c884e02d0213dff0c09801e77 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 16 Apr 2024 17:30:12 +0800
Subject: [PATCH] Update utils.py

---
 src/llmtuner/train/utils.py | 111 +++++++++++++++++++++---------------
 1 file changed, 64 insertions(+), 47 deletions(-)

diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 65233f72..2835eddf 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -162,6 +162,15 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     return decay_parameters
 
 
+def _get_embedding_names(model: "PreTrainedModel") -> List[str]:
+    r"""
+    Returns a list of names of parameters in embedding.
+    """
+    result = {name for name, _ in model.get_input_embeddings().named_parameters()}
+    result.update(name for name, _ in model.get_output_embeddings().named_parameters())
+    return result
+
+
 def _create_galore_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
@@ -236,7 +245,7 @@ def _create_galore_optimizer(
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
             dict(params=decay_params, weight_decay=training_args.weight_decay),
             dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
         ]
@@ -280,82 +289,90 @@ def _create_loraplus_optimizer(
     param_groups = [
         dict(params=param_dict["lora_a"], **decay_args),
         dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
         dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
     logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
 
+
 def _create_badam_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
-
-    from transformers.trainer_pt_utils import get_parameter_names
-    decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
-    # filter out the embedding layers when using badam ratio mode
-    if finetuning_args.badam_mode == "ratio":
-        decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters)) # TODO: make it more general
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
-            "weight_decay": 0.0,
-        },
+    decay_param_names = _get_decay_parameter_names(model)
+    if finetuning_args.badam_mode == "ratio":  # filter out the embedding layers for ratio-wise badam
+        decay_param_names = [name for name in decay_param_names if name not in _get_embedding_names(model)]
+
+    decay_params, nodecay_params = [], []
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
-    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
 
-    # create BlockOptimizer
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
-        base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-        optimizer = BlockOptimizer(base_optimizer=base_optimizer,
-                                   named_parameters_list=list(model.named_parameters()),
-                                   block_prefix_list=None,
-                                   switch_block_every=finetuning_args.switch_block_every,
-                                   start_block=finetuning_args.start_block,
-                                   switch_mode=finetuning_args.switch_mode,
-                                   verbose=finetuning_args.badam_verbose)
-
-        logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
-                    f"switch block every {finetuning_args.switch_block_every} steps, "
-                    f"default start block is {finetuning_args.start_block}")
-
+
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
+
     elif finetuning_args.badam_mode == "ratio":
-        assert finetuning_args.badam_update_ratio > 0.
         from badam import BlockOptimizerRatio
-        optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
-                                        named_parameters_list=list(model.named_parameters()),
-                                        update_ratio=finetuning_args.badam_update_ratio,
-                                        mask_mode=finetuning_args.badam_mask_mode,
-                                        verbose=finetuning_args.badam_verbose,
-                                        **optimizer_kwargs)
-
-        logger.info(f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, "
-                    f"mask mode is {finetuning_args.badam_mask_mode}")
-
+
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+
     return optimizer
 
+
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
-    if finetuning_args.use_badam:
-        return _create_badam_optimizer(model, training_args, finetuning_args)
-
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)
 
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
 
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
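
The decay/no-decay grouping that the patch introduces in _create_badam_optimizer can be exercised outside the trainer. Below is a minimal, self-contained sketch, not part of the patch: the hand-written name filter stands in for _get_decay_parameter_names and _get_embedding_names, the toy nn.Sequential stands in for a PreTrainedModel, and AdamW stands in for whatever class Trainer.get_optimizer_cls_and_kwargs(training_args) would return.

import torch
from torch import nn

# Toy stand-in for a PreTrainedModel: an embedding table, a linear layer and a LayerNorm.
model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.LayerNorm(16))

# Stand-in filter: apply weight decay to everything except biases, LayerNorm weights and
# the embedding table (the embedding exclusion mirrors the ratio-mode filtering above).
decay_param_names = [
    name
    for name, _ in model.named_parameters()
    if "bias" not in name
    and not isinstance(model[int(name.split(".")[0])], (nn.LayerNorm, nn.Embedding))
]

# Same grouping loop as the patched _create_badam_optimizer.
decay_params, nodecay_params = [], []
for name, param in model.named_parameters():
    if param.requires_grad:
        if name in decay_param_names:
            decay_params.append(param)
        else:
            nodecay_params.append(param)

param_groups = [
    dict(params=nodecay_params, weight_decay=0.0),
    dict(params=decay_params, weight_decay=0.01),
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4)  # stand-in for optim_class(param_groups, **optim_kwargs)
print([len(g["params"]) for g in optimizer.param_groups])  # -> [4, 1]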