diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 542335a3..ec1de810 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -10,7 +10,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:
@@ -69,6 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -189,7 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 7c0343f5..f29945f5 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:
@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
                 self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/utils.py
index 230fdc1e..2b33af1c 100644
--- a/src/llamafactory/train/utils.py
+++ b/src/llamafactory/train/utils.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
@@ -17,8 +18,8 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
-    from transformers.modeling_utils import PreTrainedModel
+    from accelerate import Accelerator
+    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
     from ..hparams import DataArguments
@@ -156,6 +157,17 @@ def create_reward_model(
     return reward_model
 
 
+@contextmanager
+def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
+    r"""
+    Gets adapter context for the reference model.
+    """
+    with accelerator.unwrap_model(model).disable_adapter():
+        model.eval()
+        yield
+        model.train()
+
+
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
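
A minimal usage sketch of the pattern this patch introduces, assuming the surrounding trainer code: get_ref_context is a @contextmanager that temporarily disables the LoRA adapter and switches the policy model to eval mode so it can stand in as its own reference model, and the reference forward pass is wrapped in that context together with torch.no_grad(). The compute_ref_logps function, the trainer object, and the batch argument below are illustrative placeholders, not part of the patch.

from contextlib import nullcontext

import torch

from llamafactory.train.utils import get_ref_context  # helper added by this patch


def compute_ref_logps(trainer, model, batch):
    # Illustrative only: mirrors how a DPO/KTO trainer would pick the reference
    # model and the matching context before the no-grad reference forward pass.
    if trainer.ref_model is None:
        # No separate reference model: reuse the policy model with its adapter
        # disabled and dropout frozen for the duration of the context.
        ref_model = model
        ref_context = get_ref_context(trainer.accelerator, model)
    else:
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        # On exit, get_ref_context re-enables the adapter and restores train mode.
        return ref_model(**batch).logits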