diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 0ea9003c..0edae2d4 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -342,6 +342,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    use_adammini: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 9c07df66..e9ba896c 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -106,7 +106,7 @@ class CustomDPOTrainer(DPOTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 460311e4..deb3fce2 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -101,7 +101,7 @@ class CustomKTOTrainer(KTOTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 0d55bce5..58ea83d8 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -39,7 +39,7 @@ from trl.models.utils import unwrap_model_for_generation
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 
@@ -303,7 +303,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
     ) -> "torch.optim.Optimizer":
-        optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
         if optimizer is None:
             decay_params, nodecay_params = [], []
             decay_param_names = self.get_decay_parameter_names(model)
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index e8f180a6..0c457b97 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -19,7 +19,7 @@ from transformers import Trainer
 
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -57,7 +57,7 @@ class CustomTrainer(Trainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 63f925bb..45d9e26b 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -25,7 +25,7 @@ from transformers import Trainer
 
 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class PairwiseTrainer(Trainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 954bb69f..e4958aa2 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,7 +27,7 @@ from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -66,7 +66,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index ffec4776..6503cc42 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -365,8 +365,34 @@ def _create_badam_optimizer(
     return optimizer
 
+def _create_adammini_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    from adam_mini import Adam_mini
 
-def create_custom_optimzer(
+    n_embd = model.config.hidden_size
+    n_head = model.config.num_attention_heads
+    n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
+
+    logger.info("Using Adam-mini optimizer: dim {}, n_heads {}, n_kv_heads {}.".format(n_embd, n_head, n_query_groups))
+
+    optimizer = Adam_mini(
+        named_parameters=model.named_parameters(),
+        lr=training_args.learning_rate,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        weight_decay=training_args.weight_decay,
+        model_sharding=False,
+        dim=n_embd,
+        n_heads=n_head,
+        n_kv_heads=n_query_groups,
+    )
+
+    return optimizer
+
+def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
@@ -380,6 +406,9 @@ def create_custom_optimzer(
     if finetuning_args.use_badam:
         return _create_badam_optimizer(model, training_args, finetuning_args)
 
+    if finetuning_args.use_adammini:
+        return _create_adammini_optimizer(model, training_args, finetuning_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
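
For reference, a minimal standalone sketch of the code path this patch adds, constructing Adam_mini the same way _create_adammini_optimizer does but outside LLaMA-Factory. The model id and hyperparameter values are illustrative assumptions (not from the patch); the keyword arguments mirror exactly those the patch passes, and the adam-mini package must be installed.

# Sketch only: exercises the Adam_mini constructor with the arguments used in the patch.
from adam_mini import Adam_mini
from transformers import AutoModelForCausalLM

# Any causal LM whose config exposes hidden_size / num_attention_heads works;
# this particular checkpoint is an arbitrary example.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
config = model.config

n_embd = config.hidden_size
n_head = config.num_attention_heads
n_query_groups = getattr(config, "num_key_value_heads", n_head)  # fall back to MHA if no GQA

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1e-5,                  # training_args.learning_rate in the patch
    betas=(0.9, 0.999),       # training_args.adam_beta1 / adam_beta2
    eps=1e-8,                 # training_args.adam_epsilon
    weight_decay=0.01,        # training_args.weight_decay
    model_sharding=False,     # single-process training, no ZeRO/FSDP-style sharding
    dim=n_embd,
    n_heads=n_head,
    n_kv_heads=n_query_groups,
)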