support mixtral
commit 96380f5e18
parent f4657de7d5
README.md (19 lines changed)
@@ -55,9 +55,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework.
+
 [23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
 
-[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
+[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
 
 <details><summary>Full Changelog</summary>
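
The flag rename tracks the `neftune_noise_alpha` argument that recent versions of the transformers `Trainer` expose natively. A minimal sketch of the same setting in plain transformers, outside this repo's CLI (the output directory is a placeholder):

```python
from transformers import TrainingArguments

# NEFTune injects uniform noise into embedding outputs during fine-tuning;
# an alpha of 5 matches the example in the changelog entry above.
training_args = TrainingArguments(
    output_dir="output",       # placeholder path
    neftune_noise_alpha=5,     # activates NEFTune noise injection in the Trainer
)
```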
@@ -101,6 +103,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
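
The third column of this table lists each architecture's default LoRA target modules. A minimal PEFT sketch of what `q_proj,v_proj` means in practice (rank and alpha are arbitrary example values, not defaults taken from this repo):

```python
from peft import LoraConfig

# Attach low-rank adapters only to the attention query/value projections,
# matching the "q_proj,v_proj" default listed for Mistral/Mixtral above.
lora_config = LoraConfig(
    r=8,                                   # example rank, not a recommendation
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
```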
@@ -206,13 +209,13 @@ huggingface-cli login
 ### Hardware Requirement
 
-| Method | Bits | 7B    | 13B   | 30B   | 65B    |
-| ------ | ---- | ----- | ----- | ----- | ------ |
-| Full   | 16   | 160GB | 320GB | 600GB | 1200GB |
-| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  |
-| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  |
-| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   |
-| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   |
+| Method | Bits | 7B    | 13B   | 30B   | 65B    | 8x7B   |
+| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
+| Full   | 16   | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  | 200GB  |
+| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  | 120GB  |
+| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   | 80GB   |
+| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 32GB   |
 
 ## Getting Started
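
As a rough sanity check on the new 8x7B column: Mixtral 8x7B has roughly 47B total parameters, so the 4-bit QLoRA figure is dominated by the quantized weights themselves. The numbers below are approximate estimates, not measurements from this repo:

```python
# Back-of-the-envelope estimate for the 8x7B / QLoRA / 4-bit entry above.
total_params = 46.7e9              # approximate total parameter count of Mixtral 8x7B
weight_bytes = total_params * 0.5  # 4-bit weights ~= 0.5 bytes per parameter

print(f"quantized weights: ~{weight_bytes / 1024**3:.0f} GB")
# ~22 GB for the weights alone; LoRA adapters, optimizer states, activations and
# the KV cache account for the remaining headroom in the ~32 GB table entry.
```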
README_zh.md (19 lines changed)
@@ -55,9 +55,11 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## Changelog
 
+[23/12/12] We supported fine-tuning the latest mixture-of-experts model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**.
+
 [23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage.
 
-[23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** training trick. Use the `--neft_alpha` argument to enable NEFTune, e.g., `--neft_alpha 5`.
+[23/10/21] We supported the **[NEFTune](https://arxiv.org/abs/2310.05914)** training trick. Use the `--neftune_noise_alpha` argument to enable NEFTune, e.g., `--neftune_noise_alpha 5`.
 
 <details><summary>Expand Changelog</summary>
@@ -101,6 +103,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
+| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
 | [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
@@ -206,13 +209,13 @@ huggingface-cli login
 ### Hardware Requirement
 
-| Method | Bits | 7B    | 13B   | 30B   | 65B    |
-| ------ | ---- | ----- | ----- | ----- | ------ |
-| Full   | 16   | 160GB | 320GB | 600GB | 1200GB |
-| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  |
-| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  |
-| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   |
-| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   |
+| Method | Bits | 7B    | 13B   | 30B   | 65B    | 8x7B   |
+| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
+| Full   | 16   | 160GB | 320GB | 600GB | 1200GB | 1000GB |
+| Freeze | 16   | 20GB  | 40GB  | 120GB | 240GB  | 200GB  |
+| LoRA   | 16   | 16GB  | 32GB  | 80GB  | 160GB  | 120GB  |
+| QLoRA  | 8    | 10GB  | 16GB  | 40GB  | 80GB   | 80GB   |
+| QLoRA  | 4    | 6GB   | 12GB  | 24GB  | 48GB   | 32GB   |
 
 ## Getting Started
requirements.txt

@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.31.0,<4.35.0
+transformers>=4.36.0
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0
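
The new floor lines up with the first transformers release that ships the `mixtral` architecture. A quick way to see the dependency (the model name is taken from the diff; the heavyweight download is left commented out):

```python
from transformers import AutoConfig

# Resolving the config alone is enough to show why >=4.36 is required:
# older releases do not recognize the "mixtral" model type.
config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
print(config.model_type)  # -> "mixtral"

# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "mistralai/Mixtral-8x7B-v0.1", device_map="auto"
# )  # commented out: the full bf16 checkpoint is on the order of 90 GB
```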
@@ -541,9 +541,7 @@ register_template(
         "[INST] {{query}} [/INST]"
     ],
     system="",
-    sep=[
-        " "
-    ]
+    sep=[]
 )
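
A hypothetical illustration, not this repo's actual renderer, of what dropping the single-space separator changes when multi-turn history is concatenated: the official Mistral chat format places `[INST]` immediately after the previous `</s>`, which is what `sep=[]` reproduces.

```python
# Toy renderer for illustration only; the real template engine lives in llmtuner.
def render(history, sep):
    return sep.join(f"[INST] {query} [/INST]{answer}" for query, answer in history)

history = [("Hi", " Hello!</s>"), ("Who are you?", "")]
print(render(history, sep=" "))  # old: "...Hello!</s> [INST] Who are you? [/INST]"
print(render(history, sep=""))   # new: "...Hello!</s>[INST] Who are you? [/INST]"
```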
@@ -382,6 +382,22 @@ register_model_group(
         "Mistral-7B-Chat": {
             DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
             DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
+        },
+        "Mistral-7B-v0.2-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2"
+        }
+    },
+    template="mistral"
+)
+
+
+register_model_group(
+    models={
+        "Mixtral-8x7B": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1"
+        },
+        "Mixtral-8x7B-Chat": {
+            DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1"
         }
     },
     template="mistral"
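
For readers unfamiliar with the registry, the net effect of the new group is a name-to-checkpoint mapping plus a default chat template. A hypothetical stand-in (the dict and helper below are illustrative, not functions from this codebase):

```python
# Illustrative mapping mirroring the register_model_group() call above.
MIXTRAL_GROUP = {
    "Mixtral-8x7B": "mistralai/Mixtral-8x7B-v0.1",
    "Mixtral-8x7B-Chat": "mistralai/Mixtral-8x7B-Instruct-v0.1",
}
MIXTRAL_TEMPLATE = "mistral"

def resolve(name):
    """Return (hub_repo, chat_template) for a registered Mixtral variant."""
    return MIXTRAL_GROUP[name], MIXTRAL_TEMPLATE

print(resolve("Mixtral-8x7B-Chat"))  # ('mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistral')
```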
@@ -25,7 +25,6 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype, try_download_model_from_ms
 from llmtuner.extras.packages import is_flash_attn2_available
-from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
 from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
@@ -38,7 +37,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
+require_version("transformers>=4.36.0", "To fix: pip install transformers>=4.36.0")
 require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
@@ -124,28 +123,22 @@ def load_model_and_tokenizer(
 
     # Set FlashAttention-2
     if model_args.flash_attn:
-        if getattr(config, "model_type", None) == "llama":
-            if is_flash_attn2_available():
-                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-                logger.info("Using FlashAttention-2 for faster training and inference.")
-            else:
-                logger.warning("FlashAttention-2 is not installed.")
-        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+        elif getattr(config, "model_type", None) == "qwen":
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
-            logger.warning("Current model does not support FlashAttention.")
-    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
-        logger.warning("Using `--flash_attn` for faster training in large context length.")
+            setattr(config, "attn_implementation", "flash_attention_2")
+            logger.info("Using FlashAttention-2 for faster training and inference.")
 
     # Set shift short attention (S^2-Attn)
     if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(config, "group_size_ratio", 0.25)
-            logger.info("Using shift short attention with group_size_ratio=1/4.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
+        logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
+        # if getattr(config, "model_type", None) == "llama":
+        #     setattr(config, "group_size_ratio", 0.25)
+        #     logger.info("Using shift short attention with group_size_ratio=1/4.")
+        # else:
+        #     logger.warning("Current model does not support shift short attention.")
 
     # Quantization configurations (using gptq or awq)
     if getattr(config, "quantization_config", None):
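
The rewritten branch above leans on the attention-backend selection that transformers 4.36 provides instead of monkey-patching `LlamaAttention`. The same backend can also be requested directly through the `attn_implementation` keyword of `from_pretrained`; a standalone sketch (model name and dtype are placeholders, not this function's arguments):

```python
import torch
from transformers import AutoModelForCausalLM

# In transformers >= 4.36 the library wires up FlashAttention-2 itself
# once the flash-attn package is installed; no per-model patches needed.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",            # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package
    device_map="auto",
)
```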