Fix DeepSpeed ZeRO-3 in-flight parameter error when calling model.eval()
hiyouga 2024-06-13 02:25:50 +08:00
parent 2ed8270112
commit cf9f2d6c42
4 changed files with 12 additions and 17 deletions
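Context for the change: under DeepSpeed ZeRO-3, toggling the reference model between train and eval mode while its partitioned parameters are being gathered can trip an in-flight parameter error. The commit therefore drops the get_ref_context helper (which wrapped the adapter-disabled context in model.eval() / model.train() calls) and has each trainer enter the unwrapped model's disable_adapter() context directly. Below is a minimal sketch of the resulting pattern; the function name and the trainer parameter are illustrative stand-ins for the DPO/KTO trainer instance, not code from this diff.

from contextlib import nullcontext

def pick_ref_model_and_context(trainer, model):
    """Choose the reference model and adapter context without toggling eval/train.

    Sketch of the pattern used after this commit: when no separate reference
    model exists, the policy model is reused with its adapter disabled, and the
    module's train/eval mode is left untouched, which avoids the ZeRO-3
    in-flight parameter error named in the commit title.
    """
    if trainer.ref_model is None:
        ref_model = model
        # disable_adapter() returns a context manager; no eval()/train() calls.
        ref_context = trainer.accelerator.unwrap_model(model).disable_adapter()
    else:
        ref_model = trainer.ref_model
        ref_context = nullcontext()
    return ref_model, ref_context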

View File

@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -10,7 +11,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 if TYPE_CHECKING:
@@ -61,6 +62,8 @@ class CustomDPOTrainer(DPOTrainer):
 if not hasattr(self, "accelerator"):
 raise AttributeError("Please update `transformers`.")
+warnings.simplefilter("ignore") # remove gc warnings on ref model
 if ref_model is not None:
 if self.is_deepspeed_enabled:
 if not (
@@ -176,7 +179,7 @@ class CustomDPOTrainer(DPOTrainer):
 if self.ref_model is None:
 ref_model = model
-ref_context = get_ref_context(self.accelerator, model)
+ref_context = self.accelerator.unwrap_model(model).disable_adapter()
 else:
 ref_model = self.ref_model
 ref_context = nullcontext()

View File

@@ -1,3 +1,4 @@
+import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from types import MethodType
@@ -9,7 +10,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 from ...extras.constants import IGNORE_INDEX
-from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
 if TYPE_CHECKING:
@@ -60,6 +61,8 @@ class CustomKTOTrainer(KTOTrainer):
 if not hasattr(self, "accelerator"):
 raise AttributeError("Please update `transformers`.")
+warnings.simplefilter("ignore") # remove gc warnings on ref model
 if ref_model is not None:
 if self.is_deepspeed_enabled:
 if not (
@@ -143,7 +146,7 @@ class CustomKTOTrainer(KTOTrainer):
 """
 if self.ref_model is None:
 ref_model = model
-ref_context = get_ref_context(self.accelerator, model)
+ref_context = self.accelerator.unwrap_model(model).disable_adapter()
 else:
 ref_model = self.ref_model
 ref_context = nullcontext()

View File

@@ -1,6 +1,7 @@
 import math
 import os
 import sys
+import warnings
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -136,6 +137,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 device_type = unwrapped_model.pretrained_model.device.type
 self.amp_context = torch.autocast(device_type, dtype=model_args.compute_dtype)
+warnings.simplefilter("ignore") # remove gc warnings on ref model
 if finetuning_args.reward_model_type == "full":
 if self.is_deepspeed_enabled:

View File

@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 if TYPE_CHECKING:
-from accelerate import Accelerator
 from transformers import PreTrainedModel, Seq2SeqTrainingArguments
 from trl import AutoModelForCausalLMWithValueHead
@@ -154,17 +152,6 @@ def create_reward_model(
 return reward_model
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-r"""
-Gets adapter context for the reference model.
-"""
-with accelerator.unwrap_model(model).disable_adapter():
-model.eval()
-yield
-model.train()
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
 r"""
 Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
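For completeness, a hedged sketch of how a trainer might consume the context obtained above when computing reference log-probabilities. The forward_fn parameter is a placeholder for the trainer's own forward helper and is not taken from this diff.

import torch

def compute_reference_logps(ref_model, ref_context, batch, forward_fn):
    """Run the reference forward pass with gradients disabled.

    The adapter-disabling (or null) context is entered only around the forward
    call; the module's train/eval mode is never switched, so ZeRO-3 parameter
    gathering proceeds normally.
    """
    with torch.no_grad(), ref_context:
        return forward_fn(ref_model, batch)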