diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 096fd935..292e61c7 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
@@ -101,42 +101,39 @@ class CustomKTOTrainer(KTOTrainer):
         return -all_logps
 
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        with torch.no_grad():
-            kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
-            if "pixel_values" in batch:
-                kl_model_inputs["pixel_values"] = batch["pixel_values"]
-
-            if "kl_token_type_ids" in batch:
-                kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
-
-            kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
-
-        model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+        r"""
+        Runs forward pass and computes the log probabilities.
+        """
+        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        model_inputs = {
+            "input_ids": batch["{}input_ids".format(prefix)],
+            "attention_mask": batch["{}attention_mask".format(prefix)],
+        }
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
 
-        if "token_type_ids" in batch:
-            model_inputs["token_type_ids"] = batch["token_type_ids"]
+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
 
-        target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
+        logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 
-        target_logps = self.get_batch_logps(
-            logits=target_logits,
-            labels=batch["labels"],
+        logps = self.get_batch_logps(
+            logits=logits,
+            labels=batch["{}labels".format(prefix)],
             average_log_prob=False,
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
+        return logits, logps
 
-        kl_logps = self.get_batch_logps(
-            logits=kl_logits,
-            labels=batch["kl_labels"],
-            average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
+    def concatenated_forward(
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps = self.forward(model, batch)
+        with torch.no_grad():
+            _, kl_logps = self.forward(model, batch, prefix="kl_")
 
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")
@@ -152,6 +149,30 @@ class CustomKTOTrainer(KTOTrainer):
 
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
+    def compute_reference_log_probs(
+        self, batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""
+        Computes log probabilities of the reference model.
+        """
+        if self.ref_model is None:
+            ref_model = self.model
+            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+        else:
+            ref_model = self.ref_model
+            ref_context = nullcontext()
+
+        with torch.no_grad(), ref_context:
+            (
+                reference_chosen_logps,
+                reference_rejected_logps,
+                _,
+                _,
+                reference_kl_logps,
+            ) = self.concatenated_forward(ref_model, batch)
+
+        return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
+
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
@@ -167,25 +188,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_chosen_logits,
             _,
             policy_kl_logps,
-        ) = self.forward(model, batch)
-
-        with torch.no_grad():
-            if self.ref_model is None:
-                ref_model = self.model
-                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
-            else:
-                ref_model = self.ref_model
-                ref_context = nullcontext()
-
-            with ref_context:
-                (
-                    reference_chosen_logps,
-                    reference_rejected_logps,
-                    _,
-                    _,
-                    reference_kl_logps,
-                ) = self.forward(ref_model, batch)
+        ) = self.concatenated_forward(model, batch)
 
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,