fix #2532
This commit is contained in:
parent
daa3185350
commit
3cc10a01a7
|
@@ -90,7 +90,7 @@ class RLHFArguments:
|
|||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."},
|
||||
)
|
||||
- dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
||||
+ dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
|
||||
default="sigmoid",
|
||||
metadata={"help": "The type of DPO loss to use."},
|
||||
)
|
||||
|
|
|
@@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
- loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
|
||||
+ loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
|
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
+ self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
|
|
Loading…
Reference in New Issue