fix KTO Trainer Sampler
parent 0f01500b68
commit 34a2c5087a
@@ -4,6 +4,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
+from torch.utils.data import RandomSampler
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
@@ -173,6 +174,21 @@ class CustomKTOTrainer(KTOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
+    def has_length(self, dataset):
+        """
+        Checks if the dataset implements __len__() without raising an error.
+        """
+        try:
+            return len(dataset) is not None
+        except TypeError:
+            # TypeError: len() of unsized object
+            return False
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset is None or not self.has_length(self.train_dataset):
+            return None
+        return RandomSampler(self.train_dataset)
+
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
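For context, here is a small standalone sketch (not part of the commit) that mirrors the sampler logic added above: a map-style dataset that implements __len__ gets a RandomSampler, while anything without a usable __len__ falls back to None so the caller can use its default handling. The names ToyDataset and select_train_sampler are illustrative only; has_length is the same check the diff adds to CustomKTOTrainer.

from typing import Optional

from torch.utils.data import Dataset, RandomSampler, Sampler


class ToyDataset(Dataset):
    """Tiny sized dataset used only to exercise the sampler selection."""

    def __init__(self, size: int = 8) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> int:
        return index


def has_length(dataset) -> bool:
    """True if len(dataset) works without raising (same check as in the diff)."""
    try:
        return len(dataset) is not None
    except TypeError:
        # TypeError: len() of unsized object
        return False


def select_train_sampler(train_dataset) -> Optional[Sampler]:
    """Mirrors the overridden _get_train_sampler: shuffle sized datasets, else None."""
    if train_dataset is None or not has_length(train_dataset):
        return None
    return RandomSampler(train_dataset)


if __name__ == "__main__":
    sampler = select_train_sampler(ToyDataset())
    print(list(sampler))                          # indices 0..7 in random order
    print(select_train_sampler(iter(range(8))))   # None: an iterator has no __len__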