Update utils.py
parent 6700a1b9fa
commit c9828f4c6e
@@ -162,6 +162,15 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     return decay_parameters
 
 
+def _get_embedding_names(model: "PreTrainedModel") -> List[str]:
+    r"""
+    Returns a list of names of parameters in embedding.
+    """
+    result = {name for name, _ in model.get_input_embeddings().named_parameters()}
+    result.update(name for name, _ in model.get_output_embeddings().named_parameters())
+    return result
+
+
 def _create_galore_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
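Note: the new _get_embedding_names helper simply unions the parameter names of the input and output embedding modules. A minimal sketch of that collection logic, not part of the diff; the checkpoint name is only illustrative:

    # Sketch: gather embedding parameter names the way _get_embedding_names does.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

    names = {name for name, _ in model.get_input_embeddings().named_parameters()}
    names.update(name for name, _ in model.get_output_embeddings().named_parameters())

    # named_parameters() on a submodule yields names relative to that submodule,
    # so for a plain embedding layer this is typically {"weight"}.
    print(names)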
@@ -236,7 +245,7 @@ def _create_galore_optimizer(
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
-            dict(params=nodecay_params),
+            dict(params=nodecay_params, weight_decay=0.0),
             dict(params=decay_params, weight_decay=training_args.weight_decay),
             dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
         ]
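Note: PyTorch optimizers fall back to the constructor-level default for any option a param group omits, so a group without its own weight_decay inherits the optimizer default (AdamW, for example, defaults to 0.01); the added weight_decay=0.0 pins the no-decay group to zero explicitly. A small sketch with torch.optim.AdamW, where the tensors and values are only illustrative:

    import torch

    w_decay = torch.nn.Parameter(torch.randn(4, 4))  # e.g. a projection weight
    w_nodecay = torch.nn.Parameter(torch.randn(4))   # e.g. a bias or norm weight

    optimizer = torch.optim.AdamW(
        [
            dict(params=[w_nodecay], weight_decay=0.0),  # explicitly exempt from decay
            dict(params=[w_decay], weight_decay=0.1),
        ],
        lr=1e-3,
        weight_decay=0.1,  # default inherited by any group that omits the option
    )

    print([group["weight_decay"] for group in optimizer.param_groups])  # [0.0, 0.1]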
@@ -280,82 +289,90 @@ def _create_loraplus_optimizer(
     param_groups = [
         dict(params=param_dict["lora_a"], **decay_args),
         dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
-        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
         dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
     logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
 
 
 def _create_badam_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
-    from transformers.trainer_pt_utils import get_parameter_names
-    decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
-    # filter out the embedding layers when using badam ratio mode
-    if finetuning_args.badam_mode == "ratio":
-        decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters))  # TODO: make it more general
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
-            "weight_decay": training_args.weight_decay,
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
-            "weight_decay": 0.0,
-        },
-    ]
-
-    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
-
-    # create BlockOptimizer
+    decay_param_names = _get_decay_parameter_names(model)
+    if finetuning_args.badam_mode == "ratio":  # filter out the embedding layers for ratio-wise badam
+        decay_param_names = [name for name in decay_param_names if name not in _get_embedding_names(model)]
+
+    decay_params, nodecay_params = [], []
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if name in decay_param_names:
+                decay_params.append(param)
+            else:
+                nodecay_params.append(param)
+
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=nodecay_params, weight_decay=0.0),
+        dict(params=decay_params, weight_decay=training_args.weight_decay),
+    ]
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
-        base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-        optimizer = BlockOptimizer(base_optimizer=base_optimizer,
-                                   named_parameters_list=list(model.named_parameters()),
-                                   block_prefix_list=None,
-                                   switch_block_every=finetuning_args.switch_block_every,
-                                   start_block=finetuning_args.start_block,
-                                   switch_mode=finetuning_args.switch_mode,
-                                   verbose=finetuning_args.badam_verbose)
-
-        logger.info(f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
-                    f"switch block every {finetuning_args.switch_block_every} steps, "
-                    f"default start block is {finetuning_args.start_block}")
+
+        base_optimizer = optim_class(param_groups, **optim_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.badam_switch_block_every,
+            start_block=finetuning_args.badam_start_block,
+            switch_mode=finetuning_args.badam_switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
+            f"switch block every {finetuning_args.badam_switch_block_every} steps, "
+            f"default start block is {finetuning_args.badam_start_block}"
+        )
 
     elif finetuning_args.badam_mode == "ratio":
-        assert finetuning_args.badam_update_ratio > 0.
         from badam import BlockOptimizerRatio
-        optimizer = BlockOptimizerRatio(param_groups=optimizer_grouped_parameters,
-                                        named_parameters_list=list(model.named_parameters()),
-                                        update_ratio=finetuning_args.badam_update_ratio,
-                                        mask_mode=finetuning_args.badam_mask_mode,
-                                        verbose=finetuning_args.badam_verbose,
-                                        **optimizer_kwargs)
-
-        logger.info(f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, "
-                    f"mask mode is {finetuning_args.badam_mask_mode}")
+
+        assert finetuning_args.badam_update_ratio > 1e-6
+        optimizer = BlockOptimizerRatio(
+            param_groups=param_groups,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            **optim_kwargs,
+        )
+        logger.info(
+            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
 
     return optimizer
 
 
 def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
-    if finetuning_args.use_badam:
-        return _create_badam_optimizer(model, training_args, finetuning_args)
-
     if finetuning_args.use_galore:
         return _create_galore_optimizer(model, training_args, finetuning_args)
 
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
 
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
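Note: both BAdam modes now build their param groups from a single decay/no-decay partition over trainable parameters instead of the old dict-literal groups. A self-contained sketch of that partition on a toy model, not part of the diff; the module layout and the decay rule are illustrative stand-ins for _get_decay_parameter_names:

    import torch
    from torch import nn

    # Toy stand-in for a model: decay 2-D weights, skip biases and the LayerNorm.
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))
    decay_param_names = {
        name for name, param in model.named_parameters()
        if param.ndim >= 2  # biases and LayerNorm weights are 1-D here, so they drop out
    }

    decay_params, nodecay_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            (decay_params if name in decay_param_names else nodecay_params).append(param)

    param_groups = [
        dict(params=nodecay_params, weight_decay=0.0),
        dict(params=decay_params, weight_decay=0.01),
    ]
    base_optimizer = torch.optim.AdamW(param_groups, lr=1e-3)

In layer mode this base optimizer is then wrapped by badam.BlockOptimizer; in ratio mode the same groups are handed to badam.BlockOptimizerRatio directly, as shown in the diff.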
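Note: with this change the dispatch in create_custom_optimzer checks GaLore first, then LoRA+, then BAdam, and returns None when none of them is enabled, which leaves optimizer creation to the default Trainer path. A condensed view of the resulting function, with the helper calls taken verbatim from the diff:

    def create_custom_optimzer(model, training_args, finetuning_args):
        if finetuning_args.use_galore:
            return _create_galore_optimizer(model, training_args, finetuning_args)

        if finetuning_args.loraplus_lr_ratio is not None:
            return _create_loraplus_optimizer(model, training_args, finetuning_args)

        if finetuning_args.use_badam:
            return _create_badam_optimizer(model, training_args, finetuning_args)

        return None  # the caller then falls back to the default optimizer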