clean kto trainer

This commit is contained in:
hiyouga 2024-05-28 21:43:26 +08:00
parent 1e80a3a638
commit 900e1ea622
1 changed files with 50 additions and 45 deletions

View File

@ -1,7 +1,7 @@
from collections import defaultdict
from contextlib import nullcontext
from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from transformers import Trainer
@ -101,42 +101,39 @@ class CustomKTOTrainer(KTOTrainer):
return -all_logps
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
with torch.no_grad():
kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
if "pixel_values" in batch:
kl_model_inputs["pixel_values"] = batch["pixel_values"]
if "kl_token_type_ids" in batch:
kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
r"""
Runs forward pass and computes the log probabilities.
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
model_inputs = {
"input_ids": batch["{}input_ids".format(prefix)],
"attention_mask": batch["{}attention_mask".format(prefix)],
}
if "pixel_values" in batch:
model_inputs["pixel_values"] = batch["pixel_values"]
if "token_type_ids" in batch:
model_inputs["token_type_ids"] = batch["token_type_ids"]
if "{}token_type_ids".format(prefix) in batch:
model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
target_logps = self.get_batch_logps(
logits=target_logits,
labels=batch["labels"],
logps = self.get_batch_logps(
logits=logits,
labels=batch["{}labels".format(prefix)],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
return logits, logps
kl_logps = self.get_batch_logps(
logits=kl_logits,
labels=batch["kl_labels"],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps = self.forward(model, batch)
with torch.no_grad():
_, kl_logps = self.forward(model, batch, prefix="kl_")
if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.")
@ -152,6 +149,30 @@ class CustomKTOTrainer(KTOTrainer):
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
def compute_reference_log_probs(
self, batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
r"""
Computes log probabilities of the reference model.
"""
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with torch.no_grad(), ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_kl_logps,
) = self.concatenated_forward(ref_model, batch)
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",
@ -167,25 +188,9 @@ class CustomKTOTrainer(KTOTrainer):
policy_chosen_logits,
_,
policy_kl_logps,
) = self.forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
ref_model = self.model
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
else:
ref_model = self.ref_model
ref_context = nullcontext()
with ref_context:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_kl_logps,
) = self.forward(ref_model, batch)
) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,