This commit is contained in:
hiyouga 2024-02-21 21:55:14 +08:00
parent daa3185350
commit 3cc10a01a7
2 changed files with 3 additions and 2 deletions

View File

@@ -90,7 +90,7 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."},
     )
-    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
+    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )

View File

@@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)
+        self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX