fix #1287
This commit is contained in:
parent
aff9363ce3
commit
838ed9aa87
|
@ -3,7 +3,7 @@ transformers>=4.31.0
|
|||
datasets>=2.12.0
|
||||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
trl>=0.7.1
|
||||
trl>=0.7.2
|
||||
scipy
|
||||
sentencepiece
|
||||
protobuf
|
||||
|
|
|
@ -43,7 +43,7 @@ check_min_version("4.31.0")
|
|||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
|
||||
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
|
@ -32,6 +33,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
|
@ -40,8 +42,7 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model.eval()
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue