This commit is contained in:
hiyouga 2023-10-26 17:49:41 +08:00
parent aff9363ce3
commit 838ed9aa87
3 changed files with 6 additions and 5 deletions

View File

@ -3,7 +3,7 @@ transformers>=4.31.0
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
trl>=0.7.1
trl>=0.7.2
scipy
sentencepiece
protobuf

View File

@ -43,7 +43,7 @@ check_min_version("4.31.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
def load_model_and_tokenizer(

View File

@ -1,6 +1,6 @@
import torch
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
**kwargs
):
if disable_dropout:
@ -32,6 +33,7 @@ class CustomDPOTrainer(DPOTrainer):
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = beta
self.loss_type = loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)
@ -40,8 +42,7 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)