From b87c74289d523ef88611b376074199ffd03cf103 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 16 Dec 2023 19:21:41 +0800
Subject: [PATCH] support dpo-ftx

---
 requirements.txt                        |  2 +-
 src/llmtuner/extras/packages.py         | 29 +++------
 src/llmtuner/hparams/finetuning_args.py |  8 +++
 src/llmtuner/model/loader.py            |  2 +-
 src/llmtuner/train/dpo/trainer.py       | 85 ++++++++++++++++++++++++-
 src/llmtuner/train/dpo/workflow.py      |  2 +
 6 files changed, 103 insertions(+), 25 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index bee9939d..7a9ccdf2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ transformers>=4.36.1
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0
-trl>=0.7.4
+trl==0.7.4
 gradio>=3.38.0,<4.0.0
 scipy
 sentencepiece
diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index 22d725c2..4fd53346 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -13,48 +13,37 @@ def get_package_version(name: str) -> str:
         return "0.0.0"
 
 
-_fastapi_available = is_package_available("fastapi")
-_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
-_jieba_available = is_package_available("jieba")
-_matplotlib_available = is_package_available("matplotlib")
-_nltk_available = is_package_available("nltk")
-_requests_available = is_package_available("requests")
-_rouge_available = is_package_available("rouge_chinese")
-_starlette_available = is_package_available("sse_starlette")
-_uvicorn_available = is_package_available("uvicorn")
-
-
 def is_fastapi_availble():
-    return _fastapi_available
+    return is_package_available("fastapi")
 
 
 def is_flash_attn2_available():
-    return _flash_attn2_available
+    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
 
 
 def is_jieba_available():
-    return _jieba_available
+    return is_package_available("jieba")
 
 
 def is_matplotlib_available():
-    return _matplotlib_available
+    return is_package_available("matplotlib")
 
 
 def is_nltk_available():
-    return _nltk_available
+    return is_package_available("nltk")
 
 
 def is_requests_available():
-    return _requests_available
+    return is_package_available("requests")
 
 
 def is_rouge_available():
-    return _rouge_available
+    return is_package_available("rouge_chinese")
 
 
 def is_starlette_available():
-    return _starlette_available
+    return is_package_available("sse_starlette")
 
 
 def is_uvicorn_available():
-    return _uvicorn_available
+    return is_package_available("uvicorn")
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 93d3be91..7af896fa 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -70,6 +70,14 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
+        default="sigmoid",
+        metadata={"help": "The type of DPO loss to use."}
+    )
+    dpo_ftx: Optional[float] = field(
+        default=0,
+        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
+    )
     ppo_buffer_size: Optional[int] = field(
         default=1,
         metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 72ca9782..4371a0f4 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -24,7 +24,7 @@ require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.1")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") -require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4") +require_version("trl==0.7.4", "To fix: pip install trl==0.7.4") def load_model_and_tokenizer( diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py index ccf49a7f..4de79d82 100644 --- a/src/llmtuner/train/dpo/trainer.py +++ b/src/llmtuner/train/dpo/trainer.py @@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer): def __init__( self, beta: float, + loss_type: Literal["sigmoid", "hinge"], + ftx_gamma: float, model: Union["PreTrainedModel", torch.nn.Module], ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, disable_dropout: Optional[bool] = True, - loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid", **kwargs ): if disable_dropout: @@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer): self.label_pad_token_id = IGNORE_INDEX self.padding_value = 0 self.beta = beta + self.label_smoothing = 0 + self.ftx_gamma = ftx_gamma self.loss_type = loss_type self._stored_metrics = defaultdict(lambda: defaultdict(list)) @@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer): else: self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + def sft_loss( + self, + chosen_logits: torch.FloatTensor, + chosen_labels: torch.LongTensor + ) -> torch.Tensor: + r""" + Computes supervised cross-entropy loss of given labels under the given logits. + + Returns: + A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. + """ + all_logps = self._get_batch_logps( + chosen_logits, + chosen_labels, + average_log_prob=True + ) + return -all_logps + def concatenated_forward( self, - model: Optional[torch.nn.Module] = None, - batch: Optional[Dict[str, torch.Tensor]] = None + model: "PreTrainedModel", + batch: Dict[str, torch.Tensor] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error @@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer): chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) return chosen_logps, rejected_logps, chosen_logits, rejected_logits + + def get_batch_metrics( + self, + model: "PreTrainedModel", + batch: Dict[str, torch.Tensor], + train_eval: Optional[Literal["train", "eval"]] = "train" + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + r""" + Computes the DPO loss and other metrics for the given batch of inputs for train or test. 
+ """ + metrics = {} + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + with torch.no_grad(): + if self.ref_model is None: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + if self.ftx_gamma > 1e-6: + batch_size = batch["input_ids"].size(0) // 2 + chosen_labels, _ = batch["labels"].split(batch_size, dim=0) + losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels) + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() + + return losses.mean(), metrics diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py index 7ce2d44c..12a6b545 100644 --- a/src/llmtuner/train/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -47,6 +47,8 @@ def run_dpo( # Initialize our Trainer trainer = CustomDPOTrainer( beta=finetuning_args.dpo_beta, + loss_type=finetuning_args.dpo_loss, + ftx_gamma=finetuning_args.dpo_ftx, model=model, ref_model=ref_model, args=training_args,