remove gc warnings in DPO&KTO
commit f9a206509e
parent 30a538e2db
@@ -10,7 +10,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context


 if TYPE_CHECKING:
@@ -69,6 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()

         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
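Note on the fix (my reading, since the commit title does not spell it out): "gc" here is gradient checkpointing. transformers-style models only take the checkpointed path when `self.training` is true, so a frozen reference model left in train mode routes every no-grad forward through `torch.utils.checkpoint`, which warns that none of the inputs require grad. The added `self.ref_model.eval()` sidesteps that path. A minimal runnable sketch of the gating, with `TinyRefModel` as a hypothetical stand-in rather than code from this repository:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyRefModel(nn.Module):
    """Hypothetical stand-in that gates checkpointing on train mode, as transformers models do."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gradient_checkpointing = True

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # Checkpointed path: under torch.no_grad() this emits
            # "None of the inputs have requires_grad=True ..." on every call.
            return checkpoint(self.linear, x, use_reentrant=True)
        return self.linear(x)


ref_model = TinyRefModel()
x = torch.randn(2, 4)

with torch.no_grad():
    ref_model(x)  # train mode: checkpointed path, source of the warning spam

ref_model.eval()  # the one-line fix from the hunk above
with torch.no_grad():
    ref_model(x)  # eval mode: plain forward, no checkpoint involvement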
@@ -189,7 +190,7 @@ class CustomDPOTrainer(DPOTrainer):

         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
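With this change both branches hand back a context manager that the trainer can enter uniformly while computing reference log probabilities; the new helper additionally flips the unwrapped model to eval mode for the duration (see its definition further down). A small runnable sketch of the consumption pattern, using stand-in objects rather than code from this commit:

from contextlib import nullcontext

import torch

ref_model = torch.nn.Linear(4, 4).eval()  # stands in for self.ref_model
ref_context = nullcontext()               # else branch; get_ref_context(...) in the adapter case

with torch.no_grad(), ref_context:
    ref_logits = ref_model(torch.randn(2, 4))  # reference forward pass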
@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model

 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context


 if TYPE_CHECKING:
@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()

         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

 import torch
@@ -17,8 +18,8 @@ if is_galore_available():


 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
-    from transformers.modeling_utils import PreTrainedModel
+    from accelerate import Accelerator
+    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead

     from ..hparams import DataArguments
@@ -156,6 +157,17 @@ def create_reward_model(
     return reward_model


+@contextmanager
+def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
+    r"""
+    Gets adapter context for the reference model.
+    """
+    with accelerator.unwrap_model(model).disable_adapter():
+        model.eval()
+        yield
+    model.train()
+
+
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
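End-to-end behavior of the new helper, as a self-contained sketch: `DummyAdapterModel` and `DummyAccelerator` are hypothetical stand-ins (the real objects are a peft-wrapped model and an `accelerate.Accelerator`). Inside the context, the policy model acts as its own adapter-free, eval-mode reference model; train mode is restored on exit.

from contextlib import contextmanager


class DummyAdapterModel:
    """Hypothetical stand-in for a peft-wrapped policy model."""

    def __init__(self):
        self.training = True

    @contextmanager
    def disable_adapter(self):  # mimics peft.PeftModel.disable_adapter()
        yield  # adapters would be bypassed inside this block

    def eval(self):
        self.training = False

    def train(self):
        self.training = True


class DummyAccelerator:
    """Hypothetical stand-in for accelerate.Accelerator."""

    def unwrap_model(self, model):
        return model


@contextmanager
def get_ref_context(accelerator, model):
    # Same body as the helper added above, minus the type annotations.
    with accelerator.unwrap_model(model).disable_adapter():
        model.eval()
        yield
    model.train()


model, accelerator = DummyAdapterModel(), DummyAccelerator()
with get_ref_context(accelerator, model):
    assert not model.training  # acts as the frozen reference model here
assert model.training  # train mode restored after the block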