diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index e554e45a..8a73c844 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -1,5 +1,6 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence
 
+import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 
@@ -10,6 +11,13 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
+
+    set_z3_leaf_modules(model, leaf_modules)
+
+
 def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     r"""
     Sets module as a leaf module to skip partitioning in deepspeed zero3.
@@ -17,33 +25,30 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return
 
-    require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
-    from deepspeed.utils import set_z3_leaf_modules  # type: ignore
-
     if getattr(model.config, "model_type", None) == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN
 
-        set_z3_leaf_modules(model, [DbrxFFN])
+        _set_z3_leaf_modules(model, [DbrxFFN])
 
     if getattr(model.config, "model_type", None) == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
-        set_z3_leaf_modules(model, [JambaSparseMoeBlock])
+        _set_z3_leaf_modules(model, [JambaSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
 
-        set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
+        _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
 
     if getattr(model.config, "model_type", None) == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
-        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+        _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
     if getattr(model.config, "model_type", None) == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
 
-        set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
+        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
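For context, a minimal usage sketch of the refactored helper. The checkpoint name and the bare `from_pretrained` call are illustrative assumptions, not part of this patch; in LLaMA-Factory itself `add_z3_leaf_module` is called from the model-loading pipeline, and it only takes effect when a DeepSpeed ZeRO-3 config is active:

```python
# Illustrative sketch (not part of the patch). Marking each sparse-MoE block
# as a ZeRO-3 leaf module keeps DeepSpeed from partitioning parameters inside
# the expert-routing block, which otherwise breaks dynamic expert dispatch.
from transformers import AutoModelForCausalLM

from llamafactory.model.model_utils.moe import add_z3_leaf_module

# Hypothetical checkpoint chosen to hit the "mixtral" branch above.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
add_z3_leaf_module(model)  # no-op unless DeepSpeed ZeRO-3 is enabled
```

A note on the refactor itself: extracting `_set_z3_leaf_modules` moves the `require_version` check and the `deepspeed.utils` import out of the dispatch function, so the deepspeed>=0.13.0 requirement is only asserted when a supported MoE architecture is actually matched, and each per-architecture branch shrinks to a single call.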