From 67fe822324a9f830175e44f89acdd9d759b38852 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 6 Jun 2024 00:50:32 +0800
Subject: [PATCH] fix #4090

---
 requirements.txt                      |  2 +-
 src/llamafactory/extras/misc.py       |  2 +-
 src/llamafactory/train/dpo/trainer.py | 29 +++++++++++------------------
 3 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 9e00555e..7b6cbee9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.41.2
 datasets>=2.16.0
 accelerate>=0.30.1
 peft>=0.11.1
-trl>=0.8.6
+trl>=0.9.3
 gradio>=4.0.0
 scipy
 einops
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 638c24cf..78f71847 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -65,7 +65,7 @@ def check_dependencies() -> None:
     require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
     require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
     require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
-    require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
+    require_version("trl>=0.9.3", "To fix: pip install trl>=0.9.3")
 
 
 def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index ec1de810..2bbe6a06 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -93,18 +93,6 @@ class CustomDPOTrainer(DPOTrainer):
             output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
-    def sft_loss(self, batch: Dict[str, "torch.Tensor"], chosen_logits: "torch.FloatTensor") -> "torch.Tensor":
-        r"""
-        Computes supervised cross-entropy loss of given labels under the given logits.
-
-        Returns:
-            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
-        """
-        batch_size = batch["input_ids"].size(0) // 2
-        chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
-        chosen_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
-        return -chosen_logps
-
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
         Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -156,9 +144,9 @@ class CustomDPOTrainer(DPOTrainer):
 
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
-        Computes the sum log probabilities of the labels under the given logits if loss_type != IPO.
+        Computes the sum log probabilities of the labels under given logits if loss_type is not IPO, ORPO or SimPO.
 
         Otherwise the average log probabilities.
         """
@@ -167,17 +155,20 @@ class CustomDPOTrainer(DPOTrainer):
 
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
 
-        all_logps = self.get_batch_logps(
+        all_logps, valid_length = self.get_batch_logps(
             logits=all_logits,
             labels=batch["labels"],
-            average_log_prob=(self.loss_type in ["ipo", "orpo", "simpo"]),
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
         )
+        if self.loss_type in ["ipo", "orpo", "simpo"]:
+            all_logps = all_logps / valid_length
+
         batch_size = batch["input_ids"].size(0) // 2
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
-        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+        chosen_length, _ = valid_length.split(batch_size, dim=0)
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
 
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
@@ -201,6 +192,7 @@ class CustomDPOTrainer(DPOTrainer):
                 reference_rejected_logps,
                 _,
                 _,
+                _,
             ) = self.concatenated_forward(ref_model, batch)
 
         return reference_chosen_logps, reference_rejected_logps
@@ -220,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
             policy_rejected_logps,
             policy_chosen_logits,
             policy_rejected_logits,
+            policy_chosen_logps_avg,
         ) = self.concatenated_forward(model, batch)
 
         reference_chosen_logps, reference_rejected_logps = self.compute_reference_log_probs(model, batch)
@@ -229,7 +222,7 @@ class CustomDPOTrainer(DPOTrainer):
             reference_chosen_logps,
             reference_rejected_logps,
         )
-        sft_loss = self.sft_loss(batch, policy_chosen_logits)  # compute chosen_logps with masks
+        sft_loss = -policy_chosen_logps_avg
         if self.ftx_gamma > 1e-6:
            losses += self.ftx_gamma * sft_loss