Update utils.py

This commit is contained in:
hoshi-hiyouga 2024-04-16 17:30:12 +08:00 committed by GitHub
parent 6700a1b9fa
commit c9828f4c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 64 additions and 47 deletions

View File

@ -162,6 +162,15 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
return decay_parameters return decay_parameters
def _get_embedding_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters in embedding.
"""
result = {name for name, _ in model.get_input_embeddings().named_parameters()}
result.update(name for name, _ in model.get_output_embeddings().named_parameters())
return result
def _create_galore_optimizer( def _create_galore_optimizer(
model: "PreTrainedModel", model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
@ -236,7 +245,7 @@ def _create_galore_optimizer(
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict) optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
else: else:
param_groups = [ param_groups = [
dict(params=nodecay_params), dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay), dict(params=decay_params, weight_decay=training_args.weight_decay),
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs), dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
] ]
@ -280,82 +289,90 @@ def _create_loraplus_optimizer(
param_groups = [ param_groups = [
dict(params=param_dict["lora_a"], **decay_args), dict(params=param_dict["lora_a"], **decay_args),
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args), dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr), dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args), dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
] ]
optimizer = optim_class(param_groups, **optim_kwargs) optimizer = optim_class(param_groups, **optim_kwargs)
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio)) logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
return optimizer return optimizer
def _create_badam_optimizer( def _create_badam_optimizer(
model: "PreTrainedModel", model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer": ) -> "torch.optim.Optimizer":
decay_param_names = _get_decay_parameter_names(model)
from transformers.trainer_pt_utils import get_parameter_names if finetuning_args.badam_mode == "ratio": # filter out the embedding layers for ratio-wise badam
decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS))) decay_param_names = [name for name in decay_param_names if name not in _get_embedding_names(model)]
# filter out the embedding layers when using badam ratio mode
if finetuning_args.badam_mode == "ratio": decay_params, nodecay_params = [], []
decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters)) # TODO: make it more general for name, param in model.named_parameters():
optimizer_grouped_parameters = [ if param.requires_grad:
{ if name in decay_param_names:
"params": [p for n, p in model.named_parameters() if n in decay_parameters], decay_params.append(param)
"weight_decay": training_args.weight_decay, else:
}, nodecay_params.append(param)
{
"params": [p for n, p in model.named_parameters() if n not in decay_parameters], optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
"weight_decay": 0.0, param_groups = [
}, dict(params=nodecay_params, weight_decay=0.0),
dict(params=decay_params, weight_decay=training_args.weight_decay),
] ]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
# create BlockOptimizer
if finetuning_args.badam_mode == "layer": if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer from badam import BlockOptimizer
base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
optimizer = BlockOptimizer(base_optimizer=base_optimizer, base_optimizer = optim_class(param_groups, **optim_kwargs)
named_parameters_list=list(model.named_parameters()), optimizer = BlockOptimizer(
block_prefix_list=None, base_optimizer=base_optimizer,
switch_block_every=finetuning_args.switch_block_every, named_parameters_list=list(model.named_parameters()),
start_block=finetuning_args.start_block, block_prefix_list=None,
switch_mode=finetuning_args.switch_mode, switch_block_every=finetuning_args.badam_switch_block_every,
verbose=finetuning_args.badam_verbose) start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode,
logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, " verbose=finetuning_args.badam_verbose,
f"switch block every {finetuning_args.switch_block_every} steps, " )
f"default start block is {finetuning_args.start_block}") logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_block_every} steps, "
f"default start block is {finetuning_args.badam_start_block}"
)
elif finetuning_args.badam_mode == "ratio": elif finetuning_args.badam_mode == "ratio":
assert finetuning_args.badam_update_ratio > 0.
from badam import BlockOptimizerRatio from badam import BlockOptimizerRatio
optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
named_parameters_list=list(model.named_parameters()), assert finetuning_args.badam_update_ratio > 1e-6
update_ratio=finetuning_args.badam_update_ratio, optimizer = BlockOptimizerRatio(
mask_mode=finetuning_args.badam_mask_mode, param_groups=param_groups,
verbose=finetuning_args.badam_verbose, named_parameters_list=list(model.named_parameters()),
**optimizer_kwargs) update_ratio=finetuning_args.badam_update_ratio,
mask_mode=finetuning_args.badam_mask_mode,
logger.info(f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, " verbose=finetuning_args.badam_verbose,
f"mask mode is {finetuning_args.badam_mask_mode}") **optim_kwargs,
)
logger.info(
f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
f"mask mode is {finetuning_args.badam_mask_mode}"
)
return optimizer return optimizer
def create_custom_optimzer( def create_custom_optimzer(
model: "PreTrainedModel", model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]: ) -> Optional["torch.optim.Optimizer"]:
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_galore: if finetuning_args.use_galore:
return _create_galore_optimizer(model, training_args, finetuning_args) return _create_galore_optimizer(model, training_args, finetuning_args)
if finetuning_args.loraplus_lr_ratio is not None: if finetuning_args.loraplus_lr_ratio is not None:
return _create_loraplus_optimizer(model, training_args, finetuning_args) return _create_loraplus_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args)
def create_custom_scheduler( def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",