forked from p04798526/LLaMA-Factory-Mirror
clean kto trainer
commit 900e1ea622
parent 1e80a3a638
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
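The only change in this hunk is the added `Literal` import, which types the `prefix` parameter introduced in `forward` below. A minimal standalone sketch of the pattern (`batch_key` is a hypothetical helper, not part of the commit):

from typing import Literal


def batch_key(name: str, prefix: Literal["", "kl_"] = "") -> str:
    # A static type checker flags any call site passing a prefix other
    # than "" or "kl_"; at runtime this is an ordinary string argument.
    return "{}{}".format(prefix, name)


assert batch_key("input_ids") == "input_ids"
assert batch_key("input_ids", prefix="kl_") == "kl_input_ids"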
@@ -101,42 +101,39 @@ class CustomKTOTrainer(KTOTrainer):
         return -all_logps
 
     def forward(
-        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        with torch.no_grad():
-            kl_model_inputs = {"input_ids": batch["kl_input_ids"], "attention_mask": batch["kl_attention_mask"]}
-            if "pixel_values" in batch:
-                kl_model_inputs["pixel_values"] = batch["pixel_values"]
-
-            if "kl_token_type_ids" in batch:
-                kl_model_inputs["token_type_ids"] = batch["kl_token_type_ids"]
-
-            kl_logits = model(**kl_model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
-
-        model_inputs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
+    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+        r"""
+        Runs forward pass and computes the log probabilities.
+        """
+        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        model_inputs = {
+            "input_ids": batch["{}input_ids".format(prefix)],
+            "attention_mask": batch["{}attention_mask".format(prefix)],
+        }
         if "pixel_values" in batch:
             model_inputs["pixel_values"] = batch["pixel_values"]
 
-        if "token_type_ids" in batch:
-            model_inputs["token_type_ids"] = batch["token_type_ids"]
+        if "{}token_type_ids".format(prefix) in batch:
+            model_inputs["token_type_ids"] = batch["{}token_type_ids".format(prefix)]
 
-        target_logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
+        logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
 
-        target_logps = self.get_batch_logps(
-            logits=target_logits,
-            labels=batch["labels"],
+        logps = self.get_batch_logps(
+            logits=logits,
+            labels=batch["{}labels".format(prefix)],
             average_log_prob=False,
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
+        return logits, logps
 
-        kl_logps = self.get_batch_logps(
-            logits=kl_logits,
-            labels=batch["kl_labels"],
-            average_log_prob=False,
-            is_encoder_decoder=self.is_encoder_decoder,
-            label_pad_token_id=self.label_pad_token_id,
-        )
+    def concatenated_forward(
+        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps = self.forward(model, batch)
+        with torch.no_grad():
+            _, kl_logps = self.forward(model, batch, prefix="kl_")
 
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")
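This hunk replaces two near-duplicate forward passes with one prefix-parameterized helper: `forward(model, batch)` reads `input_ids`/`attention_mask`/`labels`, while `forward(model, batch, prefix="kl_")` reads the `kl_`-prefixed keys, and only the KL pass runs under `torch.no_grad()`. A minimal standalone sketch of the same key-prefixing pattern (`prefixed_inputs` and the toy tensors are illustrative assumptions, not LLaMA-Factory code):

from typing import Dict, Literal, Tuple

import torch


def prefixed_inputs(
    batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    # Select the target fields ("" prefix) or the KL fields ("kl_" prefix)
    # from one flat batch dict, mirroring the helper's key construction.
    return (
        batch["{}input_ids".format(prefix)],
        batch["{}attention_mask".format(prefix)],
        batch["{}labels".format(prefix)],
    )


batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "labels": torch.tensor([[1, 2, 3]]),
    "kl_input_ids": torch.tensor([[4, 5, 6]]),
    "kl_attention_mask": torch.tensor([[1, 1, 1]]),
    "kl_labels": torch.tensor([[4, 5, 6]]),
}
target = prefixed_inputs(batch)            # reads input_ids / attention_mask / labels
kl = prefixed_inputs(batch, prefix="kl_")  # reads kl_input_ids / kl_attention_mask / kl_labels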
@@ -152,6 +149,30 @@ class CustomKTOTrainer(KTOTrainer):
 
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
+    def compute_reference_log_probs(
+        self, batch: Dict[str, "torch.Tensor"]
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        r"""
+        Computes log probabilities of the reference model.
+        """
+        if self.ref_model is None:
+            ref_model = self.model
+            ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
+        else:
+            ref_model = self.ref_model
+            ref_context = nullcontext()
+
+        with torch.no_grad(), ref_context:
+            (
+                reference_chosen_logps,
+                reference_rejected_logps,
+                _,
+                _,
+                reference_kl_logps,
+            ) = self.concatenated_forward(ref_model, batch)
+
+        return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
+
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
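The new `compute_reference_log_probs` centralizes the reference-model bookkeeping: with a dedicated `ref_model` the context is a no-op `nullcontext()`, while under adapter training the policy itself doubles as the reference by temporarily disabling its adapters (PEFT's `disable_adapter()` is a context manager). A minimal sketch of the same select-model-and-context pattern, with a stand-in `disable_adapter` and toy `Toy` model (assumptions for illustration, not the trainer itself):

from contextlib import contextmanager, nullcontext

import torch


@contextmanager
def disable_adapter(model):
    # Stand-in for PEFT's PeftModel.disable_adapter(): switch the adapter
    # off on entry and restore it when the block exits.
    model.adapter_enabled = False
    try:
        yield
    finally:
        model.adapter_enabled = True


class Toy:
    adapter_enabled = True


policy, ref_model = Toy(), None
if ref_model is None:
    model, ctx = policy, disable_adapter(policy)  # policy minus adapters acts as reference
else:
    model, ctx = ref_model, nullcontext()         # dedicated reference model, nothing to toggle

with torch.no_grad(), ctx:
    assert not policy.adapter_enabled  # inside the block the adapter is off
assert policy.adapter_enabled          # restored on exit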
@@ -167,25 +188,9 @@ class CustomKTOTrainer(KTOTrainer):
             policy_chosen_logits,
             _,
             policy_kl_logps,
-        ) = self.forward(model, batch)
-
-        with torch.no_grad():
-            if self.ref_model is None:
-                ref_model = self.model
-                ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
-            else:
-                ref_model = self.ref_model
-                ref_context = nullcontext()
-
-            with ref_context:
-                (
-                    reference_chosen_logps,
-                    reference_rejected_logps,
-                    _,
-                    _,
-                    reference_kl_logps,
-                ) = self.forward(ref_model, batch)
+        ) = self.concatenated_forward(model, batch)
 
+        reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(batch)
         losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
             policy_chosen_logps,
             policy_rejected_logps,
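The body of `kto_loss` is outside this diff; for orientation, here is a sketch of the standard KTO loss (after Ethayarajh et al., 2024, and TRL's KTOTrainer) consistent with the call above. The function name, `beta` default, and exact formulation are assumptions, not taken from this commit:

import torch


def kto_loss_sketch(
    policy_chosen_logps: "torch.Tensor",
    policy_rejected_logps: "torch.Tensor",
    policy_kl_logps: "torch.Tensor",
    reference_chosen_logps: "torch.Tensor",
    reference_rejected_logps: "torch.Tensor",
    reference_kl_logps: "torch.Tensor",
    beta: float = 0.1,
):
    # Shared reference point: batch-mean KL estimate between policy and
    # reference on the mismatched (KL) completions, floored at zero and
    # detached so it acts as a constant baseline.
    kl = (policy_kl_logps - reference_kl_logps).mean().clamp(min=0).detach()

    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # Desirable examples are rewarded for rising above the KL baseline,
    # undesirable ones for falling below it.
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    losses = torch.cat((chosen_losses, rejected_losses), dim=0)

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards, kl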