diff --git a/requirements.txt b/requirements.txt
index 83523eae..61d5d279 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ transformers>=4.31.0
 datasets>=2.12.0
 accelerate>=0.21.0
 peft>=0.4.0
-trl>=0.7.1
+trl>=0.7.2
 scipy
 sentencepiece
 protobuf
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index c843e8b1..a590c94b 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -43,7 +43,7 @@ check_min_version("4.31.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
+require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")


 def load_model_and_tokenizer(
diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index c1d2f054..bde02327 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -1,6 +1,6 @@
 import torch
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
@@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
+        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
         **kwargs
     ):
         if disable_dropout:
@@ -32,6 +33,7 @@
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
         self.beta = beta
+        self.loss_type = loss_type
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

         Trainer.__init__(self, model=model, **kwargs)
@@ -40,8 +42,7 @@

         if ref_model is not None:
             if self.is_deepspeed_enabled:
-                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
-                self.ref_model.eval()
+                self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
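
For context, the functional change in this diff is the new `loss_type` switch on `CustomDPOTrainer`, which trl 0.7.2's `DPOTrainer` reads when computing the preference loss. The sketch below is illustrative only and is not part of the diff or the trl implementation; it assumes the standard DPO formulation, where `"sigmoid"` gives the original `-log σ(β · logits)` objective and `"hinge"` gives the hinge variant. The function name `dpo_loss_sketch` and the tensors in the demo are made up for this example.

```python
# Illustrative sketch of what the `loss_type` option selects between (assumption,
# not the trl source): sigmoid vs. hinge DPO objectives over the policy/reference
# log-probability ratios.
import torch
import torch.nn.functional as F


def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps,
                    beta=0.1, loss_type="sigmoid"):
    # DPO "logits": how much more the policy prefers chosen over rejected,
    # relative to the frozen reference model.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios

    if loss_type == "sigmoid":
        # Original DPO loss: -log sigmoid(beta * logits)
        losses = -F.logsigmoid(beta * logits)
    elif loss_type == "hinge":
        # Hinge variant: relu(1 - beta * logits)
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type}")
    return losses.mean()


if __name__ == "__main__":
    # Made-up per-sequence log-probs for a batch of two preference pairs.
    pc = torch.tensor([-10.0, -12.0])
    pr = torch.tensor([-11.5, -12.5])
    rc = torch.tensor([-10.5, -12.2])
    rr = torch.tensor([-11.0, -12.4])
    print(dpo_loss_sketch(pc, pr, rc, rr, loss_type="sigmoid"))
    print(dpo_loss_sketch(pc, pr, rc, rr, loss_type="hinge"))
```

Defaulting `loss_type` to `"sigmoid"` keeps the previous behavior, so existing configs should be unaffected unless they opt into the hinge loss.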