From 96380f5e1887bb166be339e58ab8f65e464d4010 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 12 Dec 2023 11:39:04 +0800
Subject: [PATCH] support mixtral

---
 README.md                        | 19 +++++++++++--------
 README_zh.md                     | 19 +++++++++++--------
 requirements.txt                 |  2 +-
 src/llmtuner/data/template.py    |  4 +---
 src/llmtuner/extras/constants.py | 16 ++++++++++++++++
 src/llmtuner/model/loader.py     | 31 ++++++++++++-------------------
 6 files changed, 52 insertions(+), 39 deletions(-)

diff --git a/README.md b/README.md
index 0140f614..665c1426 100644
--- a/README.md
+++ b/README.md
@@ -55,9 +55,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework.
+
 [23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
 
-[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
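For readers unfamiliar with the renamed flag: NEFTune adds uniform noise to the embedding outputs during training, scaled by `alpha / sqrt(seq_len * hidden_dim)`. The sketch below illustrates the idea only; the hook name and wiring are hypothetical, not code from this patch.

```python
import torch

def neftune_hook(module, inputs, output, noise_alpha=5.0):
    # NEFTune (arXiv:2310.05914): during training, perturb the embedding
    # output with uniform noise of magnitude noise_alpha / sqrt(L * d).
    if module.training:
        mag = noise_alpha / (output.size(1) * output.size(2)) ** 0.5
        output = output + torch.zeros_like(output).uniform_(-mag, mag)
    return output

# Attach to the input embedding layer before training:
# model.get_input_embeddings().register_forward_hook(neftune_hook)
```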
<details><summary>Full Changelog</summary>

@@ -101,6 +103,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
@@ -206,13 +209,13 @@ huggingface-cli login
 ### Hardware Requirement
 
-| Method | Bits | 7B | 13B | 30B | 65B |
-| ------ | ---- | ----- | ----- | ----- | ------ |
-| Full | 16 | 160GB | 320GB | 600GB | 1200GB |
-| Freeze | 16 | 20GB | 40GB | 120GB | 240GB |
-| LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
-| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |
+| Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
+| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
+| Full | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| Freeze | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
+| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
+| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
 
 ## Getting Started

diff --git a/README_zh.md b/README_zh.md
index ce8e99ac..5151ab99 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -55,9 +55,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## 更新日志
 
+[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。
+
 [23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
 
-[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
+[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune,例如 `--neftune_noise_alpha 5`。
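The new 8x7B column in the hardware-requirement table above (and its Chinese counterpart below) reflects quantized training: at 4-bit, Mixtral's weights occupy roughly 24GB, so QLoRA fits on a single 32GB card. A minimal sketch of the kind of 4-bit load those numbers assume, using the stock transformers/bitsandbytes API (the quantization settings are illustrative, not mandated by this patch):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit quantization with bf16 compute, the setup popularized by QLoRA.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    quantization_config=quant_config,
    device_map="auto",  # place quantized layers across available GPUs
)
```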
<details><summary>展开日志</summary>

@@ -101,6 +103,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
@@ -206,13 +209,13 @@ huggingface-cli login
 ### 硬件依赖
 
-| 训练方法 | 精度 | 7B | 13B | 30B | 65B |
-| ------- | ---- | ----- | ----- | ----- | ------ |
-| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB |
-| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB |
-| LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
-| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
-| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |
+| 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
+| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
+| 全参数 | 16 | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB | 200GB |
+| LoRA | 16 | 16GB | 32GB | 80GB | 160GB | 120GB |
+| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB | 80GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 32GB |
 
 ## 如何使用

diff --git a/requirements.txt b/requirements.txt
index da0b2d88..f56d8cdc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.31.0,<4.35.0
+transformers>=4.36.0
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0

diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 4f8d6a51..2b25a7fd 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -541,9 +541,7 @@ register_template(
         "[INST] {{query}} [/INST]"
     ],
     system="",
-    sep=[
-        " "
-    ]
+    sep=[]
 )

diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index ee506b65..f73df95d 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -382,6 +382,22 @@ register_model_group(
         "Mistral-7B-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
+        },
+        "Mistral-7B-v0.2-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2"
+        }
+    },
+    template="mistral"
+)
+
+
+register_model_group(
+    models={
+        "Mixtral-8x7B": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1"
+        },
+        "Mixtral-8x7B-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1"
+        }
     },
     template="mistral"
 )

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 3df33c70..082ee6aa 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -25,7 +25,6 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
 from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
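On the template.py hunk above: the old `sep=[" "]` injected a stray space between the `</s>` that closes one turn and the `[INST]` that opens the next, which does not match the Mistral/Mixtral instruct format; `sep=[]` joins turns directly. A rough illustration of the rendered multi-turn prompt (the helper is hypothetical, not llmtuner code, and the exact whitespace around the assistant reply may differ from the official chat template):

```python
def render_mistral(history, query):
    # With sep=[], nothing is inserted between a finished turn and the
    # next [INST]; the old sep=[" "] would have added a space here.
    prompt = "<s>"
    for user_msg, assistant_msg in history:
        prompt += "[INST] {} [/INST]{}</s>".format(user_msg, assistant_msg)
    return prompt + "[INST] {} [/INST]".format(query)

print(render_mistral([("Hello", "Hi there!")], "How are you?"))
# <s>[INST] Hello [/INST]Hi there!</s>[INST] How are you? [/INST]
```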
-require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"") +require_version("transformers>=4.36.0", "To fix: pip install transformers>=4.36.0") require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") @@ -124,28 +123,22 @@ def load_model_and_tokenizer( # Set FlashAttention-2 if model_args.flash_attn: - if getattr(config, "model_type", None) == "llama": - if is_flash_attn2_available(): - LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 - LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask - logger.info("Using FlashAttention-2 for faster training and inference.") - else: - logger.warning("FlashAttention-2 is not installed.") - elif getattr(config, "model_type", None) in ["qwen", "Yi"]: + if not is_flash_attn2_available(): + logger.warning("FlashAttention-2 is not installed.") + elif getattr(config, "model_type", None) == "qwen": logger.info("Current model automatically enables FlashAttention if installed.") else: - logger.warning("Current model does not support FlashAttention.") - elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama": - LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention - logger.warning("Using `--flash_attn` for faster training in large context length.") + setattr(config, "attn_implementation", "flash_attention_2") + logger.info("Using FlashAttention-2 for faster training and inference.") # Set shift short attention (S^2-Attn) if is_trainable and model_args.shift_attn: - if getattr(config, "model_type", None) == "llama": - setattr(config, "group_size_ratio", 0.25) - logger.info("Using shift short attention with group_size_ratio=1/4.") - else: - logger.warning("Current model does not support shift short attention.") + logger.warning("Shift short attention is temporarily invalid due to breaking changes.") + # if getattr(config, "model_type", None) == "llama": + # setattr(config, "group_size_ratio", 0.25) + # logger.info("Using shift short attention with group_size_ratio=1/4.") + # else: + # logger.warning("Current model does not support shift short attention.") # Quantization configurations (using gptq or awq) if getattr(config, "quantization_config", None):