From 34a2c5087a174a807e5a11cae3748bcaaaf13550 Mon Sep 17 00:00:00 2001 From: "enji.zhou" Date: Mon, 3 Jun 2024 21:32:38 +0800 Subject: [PATCH 1/2] fix KTO Trainer Sampler --- src/llamafactory/train/kto/trainer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 82ae722d..3f1220a9 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -4,6 +4,7 @@ from types import MethodType from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union import torch +from torch.utils.data import RandomSampler from transformers import Trainer from trl import KTOTrainer from trl.trainer import disable_dropout_in_model @@ -173,6 +174,21 @@ class CustomKTOTrainer(KTOTrainer): return reference_chosen_logps, reference_rejected_logps, reference_kl_logps + def has_length(self,dataset): + """ + Checks if the dataset implements __len__() and it doesn't raise an error + """ + try: + return len(dataset) is not None + except TypeError: + # TypeError: len() of unsized object + return False + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not self.has_length(self.train_dataset): + return None + return RandomSampler(self.train_dataset) + def get_batch_loss_metrics( self, model: "PreTrainedModel", From 24499f40dc1d9db448a3328d2a75c60eec27feb9 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 3 Jun 2024 22:08:38 +0800 Subject: [PATCH 2/2] Update trainer.py --- src/llamafactory/train/kto/trainer.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 3f1220a9..7c0343f5 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -4,7 +4,6 @@ from types import MethodType from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union import torch -from torch.utils.data import RandomSampler from transformers import Trainer from trl import KTOTrainer from trl.trainer import disable_dropout_in_model @@ -14,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + import torch.utils.data from transformers import PreTrainedModel, ProcessorMixin from ...hparams import FinetuningArguments @@ -85,6 +85,12 @@ class CustomKTOTrainer(KTOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + r""" + Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. + """ + return Trainer._get_train_sampler(self) + def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) if self.processor is not None: @@ -174,21 +180,6 @@ class CustomKTOTrainer(KTOTrainer): return reference_chosen_logps, reference_rejected_logps, reference_kl_logps - def has_length(self,dataset): - """ - Checks if the dataset implements __len__() and it doesn't raise an error - """ - try: - return len(dataset) is not None - except TypeError: - # TypeError: len() of unsized object - return False - - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.train_dataset is None or not self.has_length(self.train_dataset): - return None - return RandomSampler(self.train_dataset) - def get_batch_loss_metrics( self, model: "PreTrainedModel",