remove gc warnings in DPO&KTO

hiyouga 2024-06-03 22:53:54 +08:00
parent 30a538e2db
commit f9a206509e
3 changed files with 20 additions and 6 deletions


@@ -10,7 +10,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context


 if TYPE_CHECKING:
@@ -69,6 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()

         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -189,7 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()


@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context


 if TYPE_CHECKING:
@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()

         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()


@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

 import torch
@@ -17,8 +18,8 @@ if is_galore_available():
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
-    from transformers.modeling_utils import PreTrainedModel
+    from accelerate import Accelerator
+    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead

     from ..hparams import DataArguments
@@ -156,6 +157,17 @@ def create_reward_model(
     return reward_model


+@contextmanager
+def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
+    r"""
+    Gets adapter context for the reference model.
+    """
+    with accelerator.unwrap_model(model).disable_adapter():
+        model.eval()
+        yield
+        model.train()
+
+
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
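For reference, a minimal sketch of how the new get_ref_context helper is meant to be used at the DPO/KTO call sites patched above. This is illustrative only and not part of the commit; the function name, the trainer object, and the batch variable are assumptions.

# Illustrative sketch (hypothetical names): using get_ref_context when the policy
# model doubles as the reference model, mirroring the pattern in the trainers above.
import torch
from contextlib import nullcontext

def compute_reference_log_probs_sketch(trainer, model, batch):
    if trainer.ref_model is None:
        # No separate reference model: temporarily disable the adapter and put the
        # policy model in eval mode for the duration of the reference forward pass.
        ref_model = model
        ref_context = get_ref_context(trainer.accelerator, model)
    else:
        ref_model = trainer.ref_model  # already prepared and set to eval() in __init__
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        logits = ref_model(**batch).logits  # reference forward pass, no gradients
    return logits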