fix #2532
This commit is contained in:
parent
daa3185350
commit
3cc10a01a7
|
@ -90,7 +90,7 @@ class RLHFArguments:
|
||||||
default=0.1,
|
default=0.1,
|
||||||
metadata={"help": "The beta parameter for the DPO loss."},
|
metadata={"help": "The beta parameter for the DPO loss."},
|
||||||
)
|
)
|
||||||
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
|
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
|
||||||
default="sigmoid",
|
default="sigmoid",
|
||||||
metadata={"help": "The type of DPO loss to use."},
|
metadata={"help": "The type of DPO loss to use."},
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
beta: float,
|
beta: float,
|
||||||
loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
|
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
|
||||||
ftx_gamma: float,
|
ftx_gamma: float,
|
||||||
model: Union["PreTrainedModel", torch.nn.Module],
|
model: Union["PreTrainedModel", torch.nn.Module],
|
||||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||||
|
@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
if ref_model is not None:
|
if ref_model is not None:
|
||||||
disable_dropout_in_model(ref_model)
|
disable_dropout_in_model(ref_model)
|
||||||
|
|
||||||
|
self.reference_free = False
|
||||||
self.use_dpo_data_collator = True # hack to avoid warning
|
self.use_dpo_data_collator = True # hack to avoid warning
|
||||||
self.generate_during_eval = False # disable at evaluation
|
self.generate_during_eval = False # disable at evaluation
|
||||||
self.label_pad_token_id = IGNORE_INDEX
|
self.label_pad_token_id = IGNORE_INDEX
|
||||||
|
|
Loading…
Reference in New Issue