This commit is contained in:
hiyouga 2024-06-07 04:18:05 +08:00
parent ccc8b64cc2
commit f9e818d79c
7 changed files with 47 additions and 54 deletions

View File

@ -298,7 +298,7 @@ huggingface-cli login
| datasets | 2.16.0 | 2.19.2 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.9.3 | 0.9.3 | | trl | 0.8.6 | 0.9.3 |
| Optional | Minimum | Recommend | | Optional | Minimum | Recommend |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |

View File

@ -298,7 +298,7 @@ huggingface-cli login
| datasets | 2.16.0 | 2.19.2 | | datasets | 2.16.0 | 2.19.2 |
| accelerate | 0.30.1 | 0.30.1 | | accelerate | 0.30.1 | 0.30.1 |
| peft | 0.11.1 | 0.11.1 | | peft | 0.11.1 | 0.11.1 |
| trl | 0.9.3 | 0.9.3 | | trl | 0.8.6 | 0.9.3 |
| 可选项 | 至少 | 推荐 | | 可选项 | 至少 | 推荐 |
| ------------ | ------- | --------- | | ------------ | ------- | --------- |

View File

@ -2,7 +2,7 @@ transformers>=4.41.2
datasets>=2.16.0 datasets>=2.16.0
accelerate>=0.30.1 accelerate>=0.30.1
peft>=0.11.1 peft>=0.11.1
trl>=0.9.3 trl>=0.8.6
gradio>=4.0.0 gradio>=4.0.0
scipy scipy
einops einops

View File

@ -65,7 +65,7 @@ def check_dependencies() -> None:
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
require_version("trl>=0.9.3", "To fix: pip install trl>=0.9.3") require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

View File

@ -10,7 +10,7 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
if TYPE_CHECKING: if TYPE_CHECKING:
@ -155,12 +155,7 @@ class CustomDPOTrainer(DPOTrainer):
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32) all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = self.get_batch_logps( all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
logits=all_logits,
labels=batch["labels"],
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if self.loss_type in ["ipo", "orpo", "simpo"]: if self.loss_type in ["ipo", "orpo", "simpo"]:
all_logps = all_logps / valid_length all_logps = all_logps / valid_length

View File

@ -9,7 +9,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
if TYPE_CHECKING: if TYPE_CHECKING:
@ -98,16 +98,6 @@ class CustomKTOTrainer(KTOTrainer):
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir) getattr(self.processor, "image_processor").save_pretrained(output_dir)
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
return -all_logps
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@ -127,28 +117,23 @@ class CustomKTOTrainer(KTOTrainer):
logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32) logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
logps = self.get_batch_logps( logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
logits=logits, return logps, logps / valid_length
labels=batch["{}labels".format(prefix)],
average_log_prob=False,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
return logits, logps
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
target_logits, target_logps = self.forward(model, batch) target_logps, target_logps_avg = self.forward(model, batch)
with torch.no_grad(): with torch.no_grad():
_, kl_logps = self.forward(model, batch, prefix="kl_") kl_logps, _ = self.forward(model, batch, prefix="kl_")
if len(target_logps) != len(batch["kto_tags"]): if len(target_logps) != len(batch["kto_tags"]):
raise ValueError("Mismatched shape of inputs and labels.") raise ValueError("Mismatched shape of inputs and labels.")
chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]] chosen_logps = target_logps[batch["kto_tags"]]
chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]] rejected_logps = target_logps[~batch["kto_tags"]]
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
@ -164,13 +149,9 @@ class CustomKTOTrainer(KTOTrainer):
ref_context = nullcontext() ref_context = nullcontext()
with torch.no_grad(), ref_context: with torch.no_grad(), ref_context:
( reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
reference_chosen_logps, ref_model, batch
reference_rejected_logps, )
_,
_,
reference_kl_logps,
) = self.concatenated_forward(ref_model, batch)
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
@ -183,14 +164,9 @@ class CustomKTOTrainer(KTOTrainer):
Computes the DPO loss and other metrics for the given batch of inputs for train or test. Computes the DPO loss and other metrics for the given batch of inputs for train or test.
""" """
metrics = {} metrics = {}
( policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
policy_chosen_logps, self.concatenated_forward(model, batch)
policy_rejected_logps, )
policy_chosen_logits,
_,
policy_kl_logps,
) = self.concatenated_forward(model, batch)
reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs( reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
model, batch model, batch
) )
@ -205,8 +181,8 @@ class CustomKTOTrainer(KTOTrainer):
losses = losses.nanmean() losses = losses.nanmean()
if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale if self.ftx_gamma > 1e-6 and len(policy_chosen_logps) > 0: # remember to rescale
sft_loss = self.sft_loss(policy_chosen_logits, batch["labels"][batch["kto_tags"]]) sft_loss = -policy_chosen_logps_avg
losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logits) * len(batch["labels"]) losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

View File

@ -1,5 +1,5 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
@ -7,6 +7,7 @@ from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.packages import is_galore_available from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments from ..hparams import FinetuningArguments, ModelArguments
@ -399,3 +400,24 @@ def create_custom_scheduler(
for param in optimizer_dict.keys(): for param in optimizer_dict.keys():
param.register_post_accumulate_grad_hook(scheduler_hook) param.register_post_accumulate_grad_hook(scheduler_hook)
def get_batch_logps(
logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
r"""
Computes the log probabilities of the given labels under the given logits.
Returns:
logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")
labels = labels[:, 1:].clone()
logits = logits[:, :-1, :]
loss_mask = labels != label_pad_token_id
labels[labels == label_pad_token_id] = 0 # dummy token
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)