diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index d860b29a..5bdb9c43 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -10,7 +11,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -61,6 +62,8 @@ class CustomDPOTrainer(DPOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
 
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -176,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 22a84e4a..3b4488fc 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -9,7 +10,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -60,6 +61,8 @@ class CustomKTOTrainer(KTOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
 
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -143,7 +146,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 2e1288e4..737c45a3 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -1,6 +1,7 @@
 import math
 import os
 import sys
+import warnings
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         device_type = unwrapped_model.pretrained_model.device.type
         self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
 
         if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 7e9cc881..48944a63 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from accelerate import Accelerator
     from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
@@ -154,17 +152,6 @@ def create_reward_model(
     return reward_model
 
 
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-    r"""
-    Gets adapter context for the reference model.
-    """
-    with accelerator.unwrap_model(model).disable_adapter():
-        model.eval()
-        yield
-        model.train()
-
-
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
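
For context, here is a minimal sketch of the pattern this patch switches to: instead of the removed `get_ref_context()` helper (which also toggled `eval()`/`train()`), the trainers now use PEFT's `disable_adapter()` context manager directly via `accelerator.unwrap_model()`. The names `trainer`, `model`, and `batch` below are illustrative stand-ins for the trainer attributes touched in the diff, not part of the patch itself.

```python
# Illustrative sketch (not part of the patch). `trainer` stands in for the
# Custom*Trainer instances above; `model` is the PEFT-wrapped policy model.
from contextlib import nullcontext

import torch


def reference_forward(trainer, model, batch):
    if trainer.ref_model is None:
        # No separate reference model: temporarily disable the LoRA adapter so
        # the frozen base weights act as the reference model.
        ref_model = model
        ref_context = trainer.accelerator.unwrap_model(model).disable_adapter()
    else:
        # A dedicated reference model exists: no adapter toggling needed.
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        return ref_model(**batch)
```

Compared with the removed helper in `trainer_utils.py`, the adapter context is consumed where it is created, and the explicit `model.eval()` / `model.train()` switching is dropped.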