fix #1287
This commit is contained in:
parent
aff9363ce3
commit
838ed9aa87
|
@ -3,7 +3,7 @@ transformers>=4.31.0
|
||||||
datasets>=2.12.0
|
datasets>=2.12.0
|
||||||
accelerate>=0.21.0
|
accelerate>=0.21.0
|
||||||
peft>=0.4.0
|
peft>=0.4.0
|
||||||
trl>=0.7.1
|
trl>=0.7.2
|
||||||
scipy
|
scipy
|
||||||
sentencepiece
|
sentencepiece
|
||||||
protobuf
|
protobuf
|
||||||
|
|
|
@ -43,7 +43,7 @@ check_min_version("4.31.0")
|
||||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||||
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
|
require_version("trl>=0.7.2", "To fix: pip install trl>=0.7.2")
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||||
from transformers import BatchEncoding, Trainer
|
from transformers import BatchEncoding, Trainer
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
from trl.trainer.utils import disable_dropout_in_model
|
from trl.trainer.utils import disable_dropout_in_model
|
||||||
|
@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
model: Union["PreTrainedModel", torch.nn.Module],
|
model: Union["PreTrainedModel", torch.nn.Module],
|
||||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||||
disable_dropout: Optional[bool] = True,
|
disable_dropout: Optional[bool] = True,
|
||||||
|
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if disable_dropout:
|
if disable_dropout:
|
||||||
|
@ -32,6 +33,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
self.label_pad_token_id = IGNORE_INDEX
|
self.label_pad_token_id = IGNORE_INDEX
|
||||||
self.padding_value = 0
|
self.padding_value = 0
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
|
self.loss_type = loss_type
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
|
|
||||||
Trainer.__init__(self, model=model, **kwargs)
|
Trainer.__init__(self, model=model, **kwargs)
|
||||||
|
@ -40,8 +42,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
|
|
||||||
if ref_model is not None:
|
if ref_model is not None:
|
||||||
if self.is_deepspeed_enabled:
|
if self.is_deepspeed_enabled:
|
||||||
self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
|
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||||
self.ref_model.eval()
|
|
||||||
else:
|
else:
|
||||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue