update dpo, kto trainer

hiyouga 2024-05-29 00:14:29 +08:00
parent 900e1ea622
commit 7c8e01bb74
2 changed files with 12 additions and 10 deletions

DPO trainer (CustomDPOTrainer)

@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
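Both files swap the same import: `disable_dropout_in_model` now comes from `trl.trainer` instead of `trl.trainer.utils`. For forks that must run against trl releases on either side of this re-export, a small shim covering both paths is one option; this is a sketch, not part of the commit:

```python
# Compatibility sketch (not in this commit): prefer the package-level
# re-export, then fall back to the utils module for trl versions that
# only expose the function there.
try:
    from trl.trainer import disable_dropout_in_model
except ImportError:
    from trl.trainer.utils import disable_dropout_in_model
```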
@@ -179,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Computes log probabilities of the reference model.
@@ -188,8 +188,8 @@ class CustomDPOTrainer(DPOTrainer):
             return None, None
 
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
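This hunk (and its twin in the KTO trainer below) is the substance of the commit: `compute_reference_log_probs` now uses the `model` argument threaded in from the caller rather than reaching for `self.model`, presumably so the reference pass runs on the same (possibly wrapped) object the trainer hands to the loss computation. The surrounding branch is the usual implicit-reference trick for adapter training: with no dedicated `ref_model`, disabling the adapter on the policy model recovers the frozen base weights. A standalone sketch of that pattern, with `run_reference_pass` as an assumed helper name:

```python
from contextlib import nullcontext

import torch


def run_reference_pass(trainer, model, batch):
    """Hypothetical helper illustrating the implicit-reference pattern above."""
    if trainer.ref_model is None:
        # Adapter case: the policy model with its adapter disabled
        # behaves as the frozen reference model.
        ref_model = model
        ref_context = trainer.accelerator.unwrap_model(model).disable_adapter()
    else:
        # A dedicated reference model was provided; nothing to toggle.
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    # Reference log-probs carry no gradients.
    with torch.no_grad(), ref_context:
        return trainer.concatenated_forward(ref_model, batch)
```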
@@ -221,7 +221,7 @@ class CustomDPOTrainer(DPOTrainer):
             policy_rejected_logits,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_logps,
             policy_rejected_logps,
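Note this is a breaking signature change: `compute_reference_log_probs` gains a leading `model` parameter, so any subclass or external caller written against the old one-argument form will fail with a `TypeError` once the trainer forwards the extra argument. A hypothetical override sketch against the new shape (assuming `CustomDPOTrainer` is importable from wherever this module lives):

```python
class CachedRefDPOTrainer(CustomDPOTrainer):  # CustomDPOTrainer: the class patched above
    """Hypothetical subclass: overrides must accept (model, batch) now."""

    def compute_reference_log_probs(self, model, batch):
        # Custom logic (caching, logging, ...) could go here before delegating.
        return super().compute_reference_log_probs(model, batch)
```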

KTO trainer (CustomKTOTrainer)

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -150,14 +150,14 @@ class CustomKTOTrainer(KTOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Computes log probabilities of the reference model.
         """
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -190,7 +190,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_kl_logps,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+            model, batch
+        )
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,
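The KTO file mirrors the DPO changes one for one; the only difference is the extra KL term, so the rewritten call site unpacks three reference tensors instead of two. A throwaway smoke test (hypothetical helper, assuming the outputs are plain tensors) makes the new contract explicit:

```python
import torch


def check_kto_reference_outputs(trainer, model, batch):
    """Hypothetical smoke test for the updated KTO signature."""
    ref_chosen, ref_rejected, ref_kl = trainer.compute_reference_log_probs(model, batch)
    # Unlike the DPO trainer, KTO also returns a tensor for the KL estimate.
    for t in (ref_chosen, ref_rejected, ref_kl):
        assert isinstance(t, torch.Tensor)
```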