diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 82ae722d..3f1220a9 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -4,6 +4,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
+from torch.utils.data import RandomSampler
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
@@ -173,6 +174,21 @@ class CustomKTOTrainer(KTOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
+    def has_length(self, dataset):
+        """
+        Checks if the dataset implements __len__() and it doesn't raise an error.
+        """
+        try:
+            return len(dataset) is not None
+        except TypeError:
+            # TypeError: len() of unsized object
+            return False
+
+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.train_dataset is None or not self.has_length(self.train_dataset):
+            return None
+        return RandomSampler(self.train_dataset)
+
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
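
For context, a minimal sketch of what the `_get_train_sampler` override does, isolated from the trainer class: it guards against unsized (e.g. iterable-style) datasets with the same `try/except TypeError` check as the patch, then falls back to a plain `RandomSampler`. The `ToyDataset` class and the standalone function names below are illustrative assumptions, not part of the patch itself.

```python
import torch
from torch.utils.data import Dataset, RandomSampler


class ToyDataset(Dataset):
    """Illustrative sized dataset (hypothetical, for demonstration only)."""

    def __init__(self, n: int = 8):
        self.data = list(range(n))

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> int:
        return self.data[idx]


def has_length(dataset) -> bool:
    """Same guard as in the patch: True only if len(dataset) succeeds."""
    try:
        return len(dataset) is not None
    except TypeError:
        # TypeError: len() of unsized object (e.g. an IterableDataset)
        return False


def get_train_sampler(train_dataset):
    """Mirrors the logic of CustomKTOTrainer._get_train_sampler."""
    if train_dataset is None or not has_length(train_dataset):
        return None
    return RandomSampler(train_dataset)


sampler = get_train_sampler(ToyDataset())
print(list(sampler))  # a random permutation of indices 0..7, e.g. [3, 0, 7, ...]
```

Returning `None` here tells the caller to build the `DataLoader` without an explicit sampler, while a sized dataset gets uniform random shuffling via `RandomSampler`.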