fix KTO Trainer Sampler
parent 0f01500b68
commit 34a2c5087a
@@ -4,6 +4,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
+from torch.utils.data import RandomSampler
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
@@ -173,6 +174,21 @@ class CustomKTOTrainer(KTOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
+    def has_length(self, dataset):
+        """
+        Checks if the dataset implements __len__() and it doesn't raise an error
+        """
+        try:
+            return len(dataset) is not None
+        except TypeError:
+            # TypeError: len() of unsized object
+            return False
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset is None or not self.has_length(self.train_dataset):
+            return None
+        return RandomSampler(self.train_dataset)
+
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
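For reference, below is a minimal, self-contained sketch of the sampler selection this commit adds, run outside the trainer. The ToyDataset class, the standalone has_length/choose_sampler functions, and the demo at the bottom are illustrative assumptions, not part of this commit; they only mirror the logic of the new _get_train_sampler override, which returns a RandomSampler for map-style datasets with a usable __len__() and falls back to no sampler otherwise.

# Illustrative sketch only -- ToyDataset and choose_sampler are assumed names,
# not part of this commit; they mirror the overridden _get_train_sampler above.
from typing import Optional

import torch
from torch.utils.data import Dataset, RandomSampler


class ToyDataset(Dataset):
    """A tiny map-style dataset with a well-defined __len__()."""

    def __init__(self, size: int = 8) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> int:
        return index


def has_length(dataset) -> bool:
    """Same check as in the trainer: True only if len(dataset) works."""
    try:
        return len(dataset) is not None
    except TypeError:
        # TypeError: len() of unsized object (e.g. a bare iterator)
        return False


def choose_sampler(train_dataset) -> Optional[torch.utils.data.Sampler]:
    """Mirrors _get_train_sampler: shuffle only when the dataset is sized."""
    if train_dataset is None or not has_length(train_dataset):
        return None
    return RandomSampler(train_dataset)


if __name__ == "__main__":
    print(list(choose_sampler(ToyDataset())))   # e.g. [3, 0, 6, 1, 7, 2, 5, 4]
    print(choose_sampler(iter(range(8))))       # None: an iterator has no __len__()

Running the script prints the eight indices in a shuffled order for the sized dataset and None for the unsized one, which is the sampler behavior the override provides for KTO training.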
|