From 9f4fe623866b10b30c6418dee116b36671274f9f Mon Sep 17 00:00:00 2001 From: liuzc Date: Mon, 15 Apr 2024 12:11:49 +0800 Subject: [PATCH] fix: mixtral output_router_logits --- src/llmtuner/model/patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py index a23d0ef3..e7807e56 100644 --- a/src/llmtuner/model/patcher.py +++ b/src/llmtuner/model/patcher.py @@ -316,7 +316,7 @@ def patch_config( if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn: setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn - if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable: + if getattr(config, "model_type", None) in ["mixtral", "qwen2_moe"] and is_trainable: setattr(config, "output_router_logits", True) init_kwargs["torch_dtype"] = model_args.compute_dtype