diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 9c07df66..e9ba896c 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -106,7 +106,7 @@ class CustomDPOTrainer(DPOTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 460311e4..deb3fce2 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -101,7 +101,7 @@ class CustomKTOTrainer(KTOTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 0d55bce5..58ea83d8 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -39,7 +39,7 @@ from trl.models.utils import unwrap_model_for_generation
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 
@@ -303,7 +303,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
     ) -> "torch.optim.Optimizer":
-        optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+        optimizer = create_custom_optimizer(model, training_args, finetuning_args)
         if optimizer is None:
             decay_params, nodecay_params = [], []
             decay_param_names = self.get_decay_parameter_names(model)
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index e8f180a6..0c457b97 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -19,7 +19,7 @@ from transformers import Trainer
 
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -57,7 +57,7 @@ class CustomTrainer(Trainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 63f925bb..45d9e26b 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -25,7 +25,7 @@ from transformers import Trainer
 
 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -65,7 +65,7 @@ class PairwiseTrainer(Trainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 954bb69f..e4958aa2 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,7 +27,7 @@ from transformers import Seq2SeqTrainer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -66,7 +66,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
     def create_scheduler(
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index ffec4776..c5688665 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -366,7 +366,7 @@ def _create_badam_optimizer(
     return optimizer
 
 
-def create_custom_optimzer(
+def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",