remove gc warnings in DPO&KTO

This commit is contained in:
hiyouga 2024-06-03 22:53:54 +08:00
parent 30a538e2db
commit f9a206509e
3 changed files with 20 additions and 6 deletions

View File

@ -10,7 +10,7 @@ from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:
@ -69,6 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
@ -189,7 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
ref_context = get_ref_context(self.accelerator, model)
else:
ref_model = self.ref_model
ref_context = nullcontext()

View File

@ -9,7 +9,7 @@ from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from ...extras.constants import IGNORE_INDEX
from ..utils import create_custom_optimzer, create_custom_scheduler
from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
if TYPE_CHECKING:
@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
self.ref_model.eval()
if finetuning_args.use_badam:
from badam import clip_grad_norm_for_sparse_tensor
@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
"""
if self.ref_model is None:
ref_model = model
ref_context = self.accelerator.unwrap_model(model).disable_adapter()
ref_context = get_ref_context(self.accelerator, model)
else:
ref_model = self.ref_model
ref_context = nullcontext()

View File

@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
import torch
@ -17,8 +18,8 @@ if is_galore_available():
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers.modeling_utils import PreTrainedModel
from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import DataArguments
@ -156,6 +157,17 @@ def create_reward_model(
return reward_model
@contextmanager
def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
r"""
Gets adapter context for the reference model.
"""
with accelerator.unwrap_model(model).disable_adapter():
model.eval()
yield
model.train()
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)