From 82bc15dc795f95768b81c25eaaabdc613da30cd8 Mon Sep 17 00:00:00 2001
From: moontidef <53668275+relic-yuexi@users.noreply.github.com>
Date: Wed, 7 Aug 2024 10:08:22 +0800
Subject: [PATCH] feat: add support for adammini

---
 src/llamafactory/hparams/finetuning_args.py |  4 +++
 src/llamafactory/train/trainer_utils.py     | 29 +++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 0ea9003c..0edae2d4 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    use_adammini: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use AdamMini optimizer."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index c5688665..6503cc42 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -365,6 +365,32 @@ def _create_badam_optimizer(
 
     return optimizer
 
+def _create_adammini_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    from adam_mini import Adam_mini
+
+    n_embd = model.config.hidden_size
+    n_head = model.config.num_attention_heads
+    n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
+
+    print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
+
+    optimizer = Adam_mini(
+        named_parameters=model.named_parameters(),
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        model_sharding=False,
+        dim=n_embd,
+        n_heads=n_head,
+        n_kv_heads=n_query_groups,
+    )
+
+    return optimizer
 
 def create_custom_optimizer(
     model: "PreTrainedModel",
@@ -380,6 +406,9 @@
     if finetuning_args.use_badam:
         return _create_badam_optimizer(model, training_args, finetuning_args)
 
+    if finetuning_args.use_adammini:
+        return _create_adammini_optimizer(model, training_args, finetuning_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
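
Usage note (not part of the patch): a minimal sketch of how the new `use_adammini` flag would be exercised once this lands. It only relies on names touched above (`FinetuningArguments`, `create_custom_optimizer`, the `adam_mini.Adam_mini` keyword arguments). The checkpoint name, output_dir, and hyperparameters are placeholders, `adam-mini` is assumed to be installed, and constructing `FinetuningArguments` directly with defaults is assumed to work; in normal use it is built by LLaMA-Factory's argument parser from a YAML/CLI config.

# Sketch only; placeholder checkpoint and paths, not a definitive integration test.
from transformers import AutoModelForCausalLM, Seq2SeqTrainingArguments

from llamafactory.hparams import FinetuningArguments
from llamafactory.train.trainer_utils import create_custom_optimizer

# Any causal LM whose config exposes hidden_size / num_attention_heads
# (and optionally num_key_value_heads) should work; this checkpoint is an example.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

training_args = Seq2SeqTrainingArguments(
    output_dir="out",        # placeholder
    learning_rate=1.0e-5,
    weight_decay=0.01,
)
finetuning_args = FinetuningArguments(use_adammini=True)

# create_custom_optimizer dispatches to _create_adammini_optimizer
# because finetuning_args.use_adammini is True.
optimizer = create_custom_optimizer(model, training_args, finetuning_args)
print(type(optimizer))  # expected: adam_mini.Adam_mini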