forked from p04798526/LLaMA-Factory-Mirror
fix #4209
DeepSpeed ZeRO-3 raises an in-flight parameter error when model.eval() is called on the reference model.
This commit is contained in:
parent 2ed8270112
commit cf9f2d6c42
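
The fix is mechanical. The get_ref_context helper in trainer_utils (removed in the last hunk below) wrapped the adapter-disable context and toggled model.eval() / model.train() on entry and exit; per #4209, that eval() call on a ZeRO-3 partitioned model is what raises the in-flight parameter error. The DPO and KTO call sites now build the context inline, using accelerator.unwrap_model(model).disable_adapter() when the policy model doubles as its own reference and nullcontext() otherwise, and each trainer gains a warnings.simplefilter("ignore") call to silence gc warnings on the ref model. Below is a minimal before/after sketch, assuming only the trainer attributes visible in the diff (self.ref_model, self.accelerator); make_ref_context is an illustrative name, not a function in the repository.

from contextlib import contextmanager, nullcontext

# Before (the helper removed from trainer_utils below): disable the adapter
# and switch the wrapped model to eval mode for the reference forward pass.
# Per #4209, the model.eval() call here trips DeepSpeed ZeRO-3's in-flight
# parameter check.
@contextmanager
def get_ref_context(accelerator, model):
    with accelerator.unwrap_model(model).disable_adapter():
        model.eval()
        yield
        model.train()

# After: the call sites build the context inline and never touch eval()/train().
# make_ref_context is an illustrative name, not part of the codebase.
def make_ref_context(trainer, model):
    if trainer.ref_model is None:
        return trainer.accelerator.unwrap_model(model).disable_adapter()
    return nullcontext()
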
@@ -1,3 +1,4 @@
 import warnings
 from collections import defaultdict
+from contextlib import nullcontext
 from types import MethodType
@@ -10,7 +11,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -61,6 +62,8 @@ class CustomDPOTrainer(DPOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
 
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -176,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()

@@ -1,3 +1,4 @@
 import warnings
 from collections import defaultdict
+from contextlib import nullcontext
 from types import MethodType
@@ -9,7 +10,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 
 
 if TYPE_CHECKING:
@@ -60,6 +61,8 @@ class CustomKTOTrainer(KTOTrainer):
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")
 
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
+
         if ref_model is not None:
             if self.is_deepspeed_enabled:
                 if not (
@@ -143,7 +146,7 @@ class CustomKTOTrainer(KTOTrainer):
         """
         if self.ref_model is None:
             ref_model = model
-            ref_context = get_ref_context(self.accelerator, model)
+            ref_context = self.accelerator.unwrap_model(model).disable_adapter()
         else:
             ref_model = self.ref_model
             ref_context = nullcontext()

@@ -1,6 +1,7 @@
 import math
 import os
 import sys
+import warnings
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         device_type = unwrapped_model.pretrained_model.device.type
         self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+        warnings.simplefilter("ignore")  # remove gc warnings on ref model
 
         if finetuning_args.reward_model_type == "full":
             if self.is_deepspeed_enabled:

@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from accelerate import Accelerator
     from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
@@ -154,17 +152,6 @@ def create_reward_model(
     return reward_model
 
 
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-    r"""
-    Gets adapter context for the reference model.
-    """
-    with accelerator.unwrap_model(model).disable_adapter():
-        model.eval()
-        yield
-        model.train()
-
-
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
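
For completeness, a hedged sketch of how a reference forward pass consumes ref_model and ref_context after this change; only the branch on self.ref_model comes from the hunks above, while the function name and the no_grad forward call are illustrative assumptions.

import torch
from contextlib import nullcontext

def reference_forward(trainer, model, batch):
    # Choose the reference model and an adapter context (pattern from the
    # DPO/KTO hunks): reuse the policy model with adapters disabled, or use
    # the dedicated reference model with no extra context.
    if trainer.ref_model is None:
        ref_model = model
        ref_context = trainer.accelerator.unwrap_model(model).disable_adapter()
    else:
        ref_model = trainer.ref_model
        ref_context = nullcontext()

    # The forward pass runs under no_grad plus the adapter context only; the
    # train/eval mode is left untouched, avoiding the ZeRO-3 error from #4209.
    with torch.no_grad(), ref_context:
        return ref_model(**batch)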