From 7c8e01bb74bb2d2da5dba5059a9c262e4730b802 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Wed, 29 May 2024 00:14:29 +0800
Subject: [PATCH] update dpo, kto trainer

---
 src/llamafactory/train/dpo/trainer.py | 10 +++++-----
 src/llamafactory/train/kto/trainer.py | 12 +++++++-----
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index f3c2443c..542335a3 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -179,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
         r"""
         Computes log probabilities of the reference model.
@@ -188,8 +188,8 @@
             return None, None
 
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -221,7 +221,7 @@ class CustomDPOTrainer(DPOTrainer):
             policy_rejected_logits,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
         losses, chosen_rewards, rejected_rewards = self.compute_preference_loss(
             policy_chosen_logps,
             policy_rejected_logps,
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 292e61c7..82ae722d 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from trl import KTOTrainer
-from trl.trainer.utils import disable_dropout_in_model
+from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
 from ..utils import create_custom_optimzer, create_custom_scheduler
@@ -150,14 +150,14 @@ class CustomKTOTrainer(KTOTrainer):
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
     def compute_reference_log_probs(
-        self, batch: Dict[str, "torch.Tensor"]
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Computes log probabilities of the reference model.
         """
         if self.ref_model is None:
-            ref_model = self.model
-            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+            ref_model = model
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -190,7 +190,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_kl_logps,
         ) = self.concatenated_forward(model, batch)
 
-        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
+            model, batch
+        )
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,