hiyouga 2023-10-26 17:49:41 +08:00
parent aff9363ce3
commit 838ed9aa87
3 changed files with 6 additions and 5 deletions

View File

@@ -3,7 +3,7 @@ transformers>=4.31.0
 datasets>=2.12.0
 accelerate>=0.21.0
 peft>=0.4.0
-trl>=0.7.1
+trl>=0.7.2
 scipy
 sentencepiece
 protobuf
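Note (not part of the diff): with the bumped pin, an existing environment only needs `pip install "trl>=0.7.2"` (or a reinstall from this requirements file); the same floor is enforced at runtime by the `require_version` check updated below.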

View File

@@ -43,7 +43,7 @@ check_min_version("4.31.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
+require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")


 def load_model_and_tokenizer(
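For context rather than part of the change: require_version (from transformers.utils.versions) raises an ImportError carrying the "To fix" hint whenever the installed package does not satisfy the specifier, so an outdated trl now fails fast at import time instead of deep inside training. A minimal sketch of that behaviour, not taken from this repository:

from transformers.utils.versions import require_version

try:
    require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
except ImportError as err:
    # The message names the version that was found and repeats the fix hint.
    print(err)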

View File

@@ -1,6 +1,6 @@
 import torch
 from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
@@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
+        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
         **kwargs
     ):
         if disable_dropout:
@@ -32,6 +33,7 @@ class CustomDPOTrainer(DPOTrainer):
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
         self.beta = beta
+        self.loss_type = loss_type
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

         Trainer.__init__(self, model=model, **kwargs)
@@ -40,8 +42,7 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             if self.is_deepspeed_enabled:
-                self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
-                self.ref_model.eval()
+                self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
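For reference, and as an assumption about the upstream library rather than something shown in this diff: the new loss_type argument is consumed by trl's DPOTrainer.dpo_loss, whose "sigmoid" and "hinge" branches in trl 0.7.x roughly reduce to the hypothetical helper below (PyTorch sketch, not code from this repository).

import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps,
                    beta=0.1, loss_type="sigmoid"):
    # Per-pair log-ratio of the policy's preference over the reference's.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        reference_chosen_logps - reference_rejected_logps
    )
    if loss_type == "sigmoid":
        # Standard DPO objective: -log sigmoid(beta * logits).
        return -F.logsigmoid(beta * logits)
    if loss_type == "hinge":
        # Hinge variant: max(0, 1 - beta * logits).
        return torch.relu(1 - beta * logits)
    raise ValueError(f"Unknown loss_type: {loss_type}")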