remove gc warnings in DPO&KTO
parent 30a538e2db
commit f9a206509e

@@ -10,7 +10,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:

@@ -69,6 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor

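The self.ref_model.eval() added above is one half of the fix. As background, a minimal sketch of where the gradient-checkpointing ("gc") warnings come from, under the assumption that the reference forward pass previously ran on a train-mode model with gradient checkpointing still enabled: PyTorch's reentrant checkpoint wrapper warns whenever none of its inputs require gradients, which is exactly the situation during a no-grad reference pass. The toy module and tensor names below are illustrative only.

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8)  # requires_grad=False, like inputs to a frozen reference model

# Checkpointed forward with no grad-requiring inputs typically emits:
# "None of the inputs have requires_grad=True. Gradients will be None"
out = checkpoint(layer, x, use_reentrant=True)

# A module kept in eval mode takes the plain forward path instead (transformers
# models gate checkpointing on self.training), so the warning never fires.
layer.eval()
with torch.no_grad():
    out = layer(x)
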
@@ -189,7 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()

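For context, a hedged sketch of how the two branches above are consumed in the trainer (the helper name reference_logits and its body are illustrative, not code from this commit; get_ref_context is the helper added to the shared utils further below): the context object is created first and entered together with torch.no_grad() only when the reference forward pass actually runs, so the model is in eval mode for exactly that pass.

from contextlib import nullcontext

import torch

def reference_logits(trainer, model, batch):
    # Illustrative helper mirroring the pattern in the hunk above.
    if trainer.ref_model is None:
        # The policy model doubles as the reference: disable its adapter for this pass.
        ref_model = model
        ref_context = get_ref_context(trainer.accelerator, model)  # defined in the utils hunk below
    else:
        # A separate frozen reference model needs no adapter toggling.
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    with torch.no_grad(), ref_context:
        return ref_model(**batch).logits
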
@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:

@@ -68,6 +68,7 @@ class CustomKTOTrainer(KTOTrainer):
                     self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+                self.ref_model.eval()
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor

@@ -164,7 +165,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
+            ref_context = get_ref_context(self.accelerator, model)
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()

@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch

@@ -17,8 +18,8 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
-    from transformers.modeling_utils import PreTrainedModel
+    from accelerate import Accelerator
+    from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
     from ..hparams import DataArguments

@@ -156,6 +157,17 @@ def create_reward_model(
     return reward_model
 
 
+@contextmanager
+def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
+    r"""
+    Gets adapter context for the reference model.
+    """
+    with accelerator.unwrap_model(model).disable_adapter():
+        model.eval()
+        yield
+    model.train()
+
+
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)

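Reading the generator back as straight-line code may make the design clearer (a sketch of the semantics only, not code from the repository; the helper name reference_pass is illustrative and assumes a PEFT-wrapped model): the adapter is disabled and the model switched to eval() only for the duration of the block, then train() is restored so the policy model is back in training mode for the optimizer step that follows.

import torch

def reference_pass(accelerator, model, batch):
    # Expanded equivalent of: with get_ref_context(accelerator, model): ...
    with accelerator.unwrap_model(model).disable_adapter():  # assumes a PEFT model with adapters
        model.eval()          # reference pass: dropout off, checkpointed forwards skipped
        with torch.no_grad():
            out = model(**batch)
    model.train()             # restore training mode for the rest of the step
    return out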