Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed ZeRO-3.
Signed-off-by: ldwang <ftgreat@gmail.com>
parent 18923b1402
commit c284665425
@@ -269,6 +269,7 @@ def patch_config(
+def patch_mixtral_replace_moe_impl() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
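The hunk above is cut off (the page never finished loading the rest of the diff). For orientation, here is a minimal sketch of what a complete patch_mixtral_replace_moe_impl plausibly looks like: mlp_forward is the per-expert MLP from the hunk, and a companion moe_forward loops over every expert on every rank so that DeepSpeed ZeRO-3 gathers each expert's partitioned weights uniformly and full training does not deadlock. The moe_forward body, the helper name moe_forward itself, and the import targets (MixtralSparseMoeBlock, MixtralBlockSparseTop2MLP) are assumptions based on the transformers Mixtral module of that era, not taken from this truncated commit.

    # Sketch only: a plausible reconstruction, not the committed code.
    # Class names assume the transformers Mixtral module circa v4.36 and
    # may differ in other versions.
    import torch
    import torch.nn.functional as F


    def patch_mixtral_replace_moe_impl() -> None:
        def mlp_forward(self, hidden_states):
            # Per-expert SwiGLU MLP, matching the hunk above.
            current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
            current_hidden_states = self.w2(current_hidden_states)
            return current_hidden_states

        def moe_forward(self, hidden_states):
            # Dense-style dispatch: every rank iterates over every expert,
            # so ZeRO-3's parameter-gathering collectives stay in sync.
            batch_size, sequence_length, hidden_dim = hidden_states.shape
            hidden_states = hidden_states.view(-1, hidden_dim)
            router_logits = self.gate(hidden_states)  # (tokens, n_experts)

            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
            topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
            topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
            topk_weight = topk_weight.to(hidden_states.dtype)

            # Duplicate each token once per selected expert, then dispatch.
            hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
            y = torch.empty_like(hidden_states)
            flat_topk_idx = topk_idx.view(-1)
            for i in range(self.num_experts):
                # Called on every rank even when the mask is empty, which is
                # what keeps ZeRO-3 from hanging on uneven expert activation.
                y[flat_topk_idx == i] = self.experts[i](hidden_states[flat_topk_idx == i])

            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            return y.reshape(batch_size, sequence_length, hidden_dim), router_logits

        from transformers.models.mixtral.modeling_mixtral import (  # assumed import path
            MixtralBlockSparseTop2MLP,
            MixtralSparseMoeBlock,
        )

        MixtralBlockSparseTop2MLP.forward = mlp_forward
        MixtralSparseMoeBlock.forward = moe_forward

Presumably patch_mixtral_replace_moe_impl() is invoked from patch_config (the context shown in the hunk header) before training starts; the trade-off is extra compute from the per-expert loop, accepted in exchange for deadlock-free full training under ZeRO-3.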