fix jetmoe z3 block
This commit is contained in:
parent
a851056229
commit
d43822fcc2
|
@ -20,6 +20,21 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
|||
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
|
||||
from deepspeed.utils import set_z3_leaf_modules # type: ignore
|
||||
|
||||
if getattr(model.config, "model_type", None) == "dbrx":
|
||||
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
|
||||
|
||||
set_z3_leaf_modules(model, [DbrxFFN])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "jamba":
|
||||
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "jetmoe":
|
||||
from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE
|
||||
|
||||
set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "mixtral":
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
|
@ -30,16 +45,6 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
|||
|
||||
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "jamba":
|
||||
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [JambaSparseMoeBlock])
|
||||
|
||||
if getattr(model.config, "model_type", None) == "dbrx":
|
||||
from transformers.models.dbrx.modeling_dbrx import DbrxFFN
|
||||
|
||||
set_z3_leaf_modules(model, [DbrxFFN])
|
||||
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if model_args.moe_aux_loss_coef is not None:
|
||||
|
@ -49,5 +54,8 @@ def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_t
|
|||
elif getattr(config, "model_type", None) == "deepseek":
|
||||
setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)
|
||||
|
||||
if getattr(config, "model_type", None) in ["dbrx", "jamba", "mixtral", "qwen2_moe"]:
|
||||
elif getattr(config, "model_type", None) == "jetmoe":
|
||||
setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)
|
||||
|
||||
if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
|
||||
setattr(config, "output_router_logits", is_trainable)
|
||||
|
|
Loading…
Reference in New Issue