From 18923b14026d88cac2631be1e5d05ba001f69ae6 Mon Sep 17 00:00:00 2001
From: ldwang
Date: Wed, 24 Jan 2024 14:43:16 +0800
Subject: [PATCH] Add patch_mixtral_replace_moe_impl for full training of
 Mixtral using DeepSpeed ZeRO-3.

Signed-off-by: ldwang
---
 src/llmtuner/model/patcher.py | 39 +++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 3b6f000a..2c7c14b3 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -268,6 +268,42 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
 
+def patch_mixtral_replace_moe_impl() -> None:
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    ## Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+    from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
@@ -290,6 +326,9 @@ def patch_model(
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
 
         set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
 
+        if is_trainable:
+            patch_mixtral_replace_moe_impl()
+
 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
     def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
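
Note (not part of the patch): the replacement forwards follow the dense-loop MoE implementation from the referenced DeepSeek-MoE modeling file, so every expert module runs on every forward pass instead of only the routed ones; this is presumably what allows DeepSpeed ZeRO-3 to gather and train all expert parameters during full fine-tuning of Mixtral. The snippet below is a minimal smoke test of the monkey-patch, sketched under stated assumptions: the llmtuner checkout containing this patch is importable, the installed transformers release still exposes the class spelled MixtralBLockSparseTop2MLP (early-2024 versions), and patcher.py has torch and torch.nn.functional as F in scope for the patched moe_forward.

    import torch
    from transformers.models.mixtral.configuration_mixtral import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    # Assumes the llmtuner package with this patch applied is installed.
    from llmtuner.model.patcher import patch_mixtral_replace_moe_impl

    # Swap the dense per-expert forwards onto the transformers Mixtral classes.
    patch_mixtral_replace_moe_impl()

    # Tiny hypothetical config so a single MoE block fits in memory for a quick check.
    config = MixtralConfig(
        hidden_size=32,
        intermediate_size=64,
        num_local_experts=4,
        num_experts_per_tok=2,
    )
    block = MixtralSparseMoeBlock(config)

    hidden_states = torch.randn(2, 5, config.hidden_size)
    final_hidden_states, router_logits = block(hidden_states)
    print(final_hidden_states.shape)  # torch.Size([2, 5, 32])
    print(router_logits.shape)        # (batch * seq, n_experts) -> torch.Size([10, 4])

The patched moe_forward keeps the same return signature as the upstream MixtralSparseMoeBlock.forward (hidden states plus router logits), so downstream code such as the auxiliary load-balancing loss should be unaffected.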