Added Mixture of Depths

This commit is contained in:
Marco 2024-04-18 20:31:24 +02:00
parent 2aaaede247
commit 620add7b9f
10 changed files with 103 additions and 6 deletions

View File

@ -46,7 +46,7 @@ Choose your path:
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Advanced algorithms**: GaLore, Mixture of Depths, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@ -68,14 +68,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/04/19] We integrated **[Mixture of Depths](https://github.com/astramind-ai/Mixture-of-depths)**. see `examples/extras/MoD` for usage.
[24/04/19] We supported **Meta Llama 3** model series.
[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
<details><summary>Full Changelog</summary>
[24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
[24/03/31] We supported **[ORPO](https://arxiv.org/abs/2403.07691)**. See `examples/lora_single_gpu` for usage.
[24/03/21] Our paper "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" is available at arXiv!

View File

@ -46,7 +46,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **多种模型**LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**增量预训练、指令监督微调、奖励模型训练、PPO 训练、DPO 训练和 ORPO 训练。
- **多种精度**32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **先进算法**GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
- **先进算法**GaLore、Mixture of Depths、BAdam、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
- **实用技巧**FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**LlamaBoard、TensorBoard、Wandb、MLflow 等等。
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@ -68,14 +68,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志
[24/04/19] 我们整合了 **[深度混合](https://github.com/astramind-ai/Mixture-of-depths)**。用法请参见 `examples/extras/MoD`
[24/04/19] 我们支持了 **Meta Llama 3** 系列模型。
[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
<details><summary>展开日志</summary>
[24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练24GB 可训练 Llama-2-7B-56k。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
[24/03/31] 我们支持了 **[ORPO](https://arxiv.org/abs/2403.07691)**。详细用法请参照 `examples/lora_single_gpu`
[24/03/21] 我们的论文 "[LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models](https://arxiv.org/abs/2403.13372)" 可在 arXiv 上查看!

View File

@ -41,6 +41,9 @@ examples/
├── llama_pro/
│ ├── expand.sh: Expand layers in the model
│ └── sft.sh: Fine-tune the expanded model
├── MoD/
│ ├── freeze_sft.sh: Freeze finetune a model, updating only the MoD router
│ └── sft.sh: Fine-tune the MoD model
└── fsdp_qlora/
└── sft.sh: Fine-tune quantized model with FSDP+QLoRA
```

View File

@ -41,6 +41,9 @@ examples/
├── llama_pro/
│ ├── expand.sh: 扩展模型中的层
│ └── sft.sh: 训练扩展后的模型
├── MoD/
│ ├── freeze_sft.sh: 冻结微调模型,仅更新 MoD 路由器
│ └── sft.sh: 微调国防部模型
└── fsdp_qlora/
└── sft.sh: 使用 FSDP+QLoRA 微调量化模型
```

View File

@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type freeze \
--name_module_trainable router \
--output_dir ../../../saves/TinyLlama/TinyLlama-1.1B-Chat-v1.0/sft \
--mixture_of_depths convert \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@ -0,0 +1,32 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--output_dir ../../../saves/TinyLlama/TinyLlama-1.1B-Chat-v1.0/sft \
--mixture_of_depths convert \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@ -69,6 +69,10 @@ class ModelArguments:
default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
)
mixture_of_depths: Optional[Literal["convert", "continue"]] = field(
default=None,
metadata={"help": "Whether or not to use MoD in the model."},
)
use_unsloth: bool = field(
default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},

View File

@ -82,6 +82,9 @@ def _check_extra_dependencies(
if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
if model_args.mixture_of_depths:
require_version("mixture-of-depth", "To fix: pip install mixture-of-depth")
if model_args.infer_backend == "vllm":
require_version("vllm>=0.3.3", "To fix: pip install vllm>=0.3.3")

View File

@ -69,6 +69,8 @@ def init_adapter(
for name, _ in model.named_modules():
if ".0." in name:
freeze_modules.add(name.split(".0.")[-1].split(".")[0])
elif ".1." in name: # here since MoD starts from layer 1
freeze_modules.add(name.split(".1.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:

View File

@ -71,6 +71,12 @@ def load_model(
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
model = None
if model_args.mixture_of_depths == 'continue':
from MoD import AutoMoDModelForCausalLM
model = AutoMoDModelForCausalLM.from_pretrained(model_args.model_name_or_path, config=config)
if model.config.model_type == 'qwen2':
RuntimeError("Qwen models are not supported for MoD training.")
if is_trainable and model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore
@ -100,6 +106,13 @@ def load_model(
init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path
model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(**init_kwargs)
if model_args.mixture_of_depths == 'convert':
from MoD import convert_hf_model
if model.config.model_type == 'qwen2':
RuntimeError("Qwen models are not supported for MoD training.")
model = convert_hf_model(model)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)