update dpo, kto trainer
commit 7c8e01bb74
parent 900e1ea622
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -179,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Computes log probabilities of the reference model.
@@ -188,8 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
             return None, None
 
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -221,7 +221,7 @@ class CustomDPOTrainer(DPOTrainer):
             policy_rejected_logits,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_logps,
             policy_rejected_logps,
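In the DPO trainer, compute_reference_log_probs now takes the model object that the training step already holds instead of reaching for self.model, and disable_dropout_in_model is imported from the trl.trainer package rather than trl.trainer.utils (the package re-exports it in current TRL releases). The practical effect: when no dedicated reference model is configured, the reference forward pass reuses the prepared model with its adapters disabled. Below is a minimal sketch of that flow, assuming the helper names from the diff (concatenated_forward, accelerator, ref_model) and that the reference pass runs under torch.no_grad(); it is an illustration, not the repository's exact code.

from contextlib import nullcontext
from typing import Dict, Optional, Tuple

import torch


class _ReferenceLogProbSketch:
    # Hypothetical class used only for illustration; in the real trainer the
    # attributes (ref_model, accelerator, concatenated_forward) come from
    # DPOTrainer / the surrounding CustomDPOTrainer.
    def compute_reference_log_probs(
        self, model, batch: Dict[str, "torch.Tensor"]
    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
        if self.ref_model is None:
            # LoRA-style training: the policy model with adapters disabled
            # serves as the reference model, so reuse the prepared `model`.
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            # Reference pass: only the chosen/rejected log-probs are needed.
            chosen_logps, rejected_logps, *_ = self.concatenated_forward(ref_model, batch)

        return chosen_logps, rejected_logps

The KTO trainer below gets the same two changes; the diff for it follows.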
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -150,14 +150,14 @@ class CustomKTOTrainer(KTOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Computes log probabilities of the reference model.
         """
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -190,7 +190,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_kl_logps,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+            model, batch
+        )
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,
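The KTO trainer receives the same treatment: the model argument is threaded into compute_reference_log_probs (the call site now wraps onto three lines), and the import path for disable_dropout_in_model changes identically. The only functional difference from the DPO side is that the forward pass also yields KL log-probabilities. A sketch under the same assumptions as the DPO sketch above, again illustrative rather than the exact repository code:

from contextlib import nullcontext
from typing import Dict, Tuple

import torch


class _KTOReferenceSketch:
    # Hypothetical class for illustration; ref_model, accelerator and
    # concatenated_forward come from KTOTrainer / CustomKTOTrainer.
    def compute_reference_log_probs(
        self, model, batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
        if self.ref_model is None:
            # No separate reference model: reuse the prepared policy model
            # with adapters disabled.
            ref_model = model
            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
        else:
            ref_model = self.ref_model
            ref_context = nullcontext()

        with torch.no_grad(), ref_context:
            # concatenated_forward returns the five-tuple shown in the diff:
            # chosen/rejected log-probs, their logits, and the KL log-probs.
            chosen_logps, rejected_logps, _, _, kl_logps = self.concatenated_forward(ref_model, batch)

        return chosen_logps, rejected_logps, kl_logps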