feat: add support for adammini

This commit is contained in:
moontidef 2024-08-07 10:08:22 +08:00
parent 40908a36fa
commit 82bc15dc79
2 changed files with 33 additions and 0 deletions

View File

@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
default=False, default=False,
metadata={"help": "Whether or not to save the training loss curves."}, metadata={"help": "Whether or not to save the training loss curves."},
) )
use_adammini: bool = field(
default=False,
metadata={"help": "Whether or not to use AdamMini optimizer."},
)
def __post_init__(self): def __post_init__(self):
def split_arg(arg): def split_arg(arg):

View File

@ -365,6 +365,32 @@ def _create_badam_optimizer(
return optimizer return optimizer
def _create_adammini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini
n_embd = model.config.hidden_size
n_head = model.config.num_attention_heads
n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
optimizer = Adam_mini(
named_parameters=model.named_parameters(),
lr=training_args.learning_rate,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
model_sharding=False,
dim=n_embd,
n_heads=n_head,
n_kv_heads=n_query_groups,
)
return optimizer
def create_custom_optimizer( def create_custom_optimizer(
model: "PreTrainedModel", model: "PreTrainedModel",
@ -380,6 +406,9 @@ def create_custom_optimizer(
if finetuning_args.use_badam: if finetuning_args.use_badam:
return _create_badam_optimizer(model, training_args, finetuning_args) return _create_badam_optimizer(model, training_args, finetuning_args)
if finetuning_args.use_adammini:
return _create_adammini_optimizer(model, training_args, finetuning_args)
def create_custom_scheduler( def create_custom_scheduler(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",