From 3cc10a01a792a92b99b952a45bb21c25097fccf6 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 21 Feb 2024 21:55:14 +0800
Subject: [PATCH] fix #2532

---
 src/llmtuner/hparams/finetuning_args.py | 2 +-
 src/llmtuner/train/dpo/trainer.py       | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 6e5cdb3f..88e4d65c 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -90,7 +90,7 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."},
     )
-    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
+    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index 2ea4707c..eb989b19 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)
 
+        self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
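
Note (not part of the patch): the rename from "kto" to "kto_pair" matches the loss_type
naming used by TRL's DPOTrainer, and self.reference_free = False is set explicitly,
presumably because this subclass initializes via Trainer rather than DPOTrainer.__init__
and the inherited loss computation reads that attribute. For orientation, below is a
minimal, illustrative sketch of what the sigmoid / hinge / ipo variants compute over
policy and reference log-probabilities. It is not the TRL code this trainer inherits;
function and argument names here are assumptions, and the kto_pair variant is omitted
for brevity.

# Illustrative sketch only; mirrors the general structure of a DPO-style loss,
# not the TRL implementation that CustomDPOTrainer inherits.
import torch
import torch.nn.functional as F


def dpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    beta: float = 0.1,
    loss_type: str = "sigmoid",
    reference_free: bool = False,  # the attribute the patch sets explicitly
) -> torch.Tensor:
    # Log-ratio of chosen vs. rejected under the policy and the reference model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    if reference_free:
        # Drop the reference term entirely when no reference model is used.
        ref_logratios = torch.zeros_like(pi_logratios)
    logits = pi_logratios - ref_logratios

    if loss_type == "sigmoid":
        # Standard DPO: negative log-sigmoid of the scaled margin.
        return -F.logsigmoid(beta * logits)
    if loss_type == "hinge":
        # SLiC-style hinge loss on the scaled margin.
        return torch.relu(1.0 - beta * logits)
    if loss_type == "ipo":
        # IPO: regress the margin toward 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise NotImplementedError(f"loss_type {loss_type!r} not covered in this sketch")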