fix KTO Trainer Sampler

This commit is contained in:
enji.zhou 2024-06-03 21:32:38 +08:00
parent 0f01500b68
commit 34a2c5087a
1 changed files with 16 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
import torch
from torch.utils.data import RandomSampler
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
@ -173,6 +174,21 @@ class CustomKTOTrainer(KTOTrainer):
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
def has_length(self,dataset):
"""
Checks if the dataset implements __len__() and it doesn't raise an error
"""
try:
return len(dataset) is not None
except TypeError:
# TypeError: len() of unsized object
return False
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not self.has_length(self.train_dataset):
return None
return RandomSampler(self.train_dataset)
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",