This commit is contained in:
hiyouga 2024-03-13 23:43:42 +08:00
parent 72367307df
commit 714d936dfb
1 changed file with 2 additions and 1 deletion

View File

@@ -294,6 +294,7 @@ def _create_loraplus_optimizer(
         dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
     ]
     optimizer = optim_class(param_groups, **optim_kwargs)
+    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
     return optimizer
@@ -303,7 +304,7 @@ def create_custom_optimzer(
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
-    if not finetuning_args.use_galore:
+    if finetuning_args.use_galore:
         return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
     if finetuning_args.loraplus_lr_ratio is not None: