diff --git a/README.md b/README.md
index beff2fb5..f3970d00 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ Choose your path:
 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
@@ -71,14 +71,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[24/08/09] We support the **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thanks to @relic-yuexi for the PR.
+
 [24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
 
 [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.
 
-[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
-
 <details><summary>Full Changelog</summary>
 
+[24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.
+
 [24/05/20] We supported fine-tuning the **PaliGemma** series models. Note that the PaliGemma models are pre-trained models, you need to fine-tune them with `gemma` template for chat completion.
 
 [24/05/18] We supported **[KTO](https://arxiv.org/abs/2402.01306)** algorithm for preference learning. See [examples](examples/README.md) for usage.
 
@@ -91,7 +93,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See [examples](examples/README.md) for usage.
 
-[24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See [examples](examples/README.md) for usage.
+[24/04/16] We supported the **[BAdam](https://arxiv.org/abs/2404.02827)** optimizer. See [examples](examples/README.md) for usage.
 
 [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
 
@@ -103,7 +105,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 [24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See [examples](examples/README.md) for usage.
 
-[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See [examples](examples/README.md) for usage.
+[24/03/07] We supported the **[GaLore](https://arxiv.org/abs/2403.03507)** optimizer. See [examples](examples/README.md) for usage.
 
 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `infer_backend: vllm` to enjoy **270%** inference speed.
 
@@ -342,7 +344,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality
+Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
 
 > [!TIP]
 > Use `pip install --no-deps -e .` to resolve package conflicts.
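The changelog entry above points to the examples for usage, but the same recipe can also be driven without the CLI. A minimal sketch, assuming `run_exp` from `llamafactory.train.tuner` accepts the flat argument dict that `llamafactory-cli train` builds from YAML; the field values mirror the example config added later in this patch:

```python
# Hedged sketch: programmatic equivalent of the CLI usage announced above.
# Assumes installation with the new extra, e.g. pip install -e ".[torch,metrics,adam-mini]",
# and that run_exp takes a flat dict of hyperparameters, as the CLI path does.
from llamafactory.train.tuner import run_exp

run_exp({
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "full",
    "use_adam_mini": True,  # the flag introduced by this patch
    "dataset": "identity,alpaca_en_demo",
    "template": "llama3",
    "output_dir": "saves/llama3-8b/full/sft",
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 8,
    "learning_rate": 1.0e-5,
    "num_train_epochs": 3.0,
    "bf16": True,
})
```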
diff --git a/README_zh.md b/README_zh.md
index 41e4acfe..3f93af5d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -49,7 +49,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 - **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
-- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
+- **先进算法**:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
@@ -71,14 +71,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 ## 更新日志
 
+[24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 @relic-yuexi 的 PR。
+
 [24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
 
 [24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。
 
-[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
-
 <details><summary>展开日志</summary>
 
+[24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
+
 [24/05/20] 我们支持了 **PaliGemma** 系列模型的微调。注意 PaliGemma 是预训练模型,你需要使用 `gemma` 模板进行微调使其获得对话能力。
 
 [24/05/18] 我们支持了 **[KTO](https://arxiv.org/abs/2402.01306)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。
 
@@ -91,7 +93,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 [24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 [examples](examples/README_zh.md)。
 
-[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 [examples](examples/README_zh.md)。
+[24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
 
 [24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
 
@@ -103,7 +105,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 [24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 [examples](examples/README_zh.md)。
 
-[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 [examples](examples/README_zh.md)。
+[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。
 
 [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `infer_backend: vllm` 来获得 **270%** 的推理速度。
 
@@ -342,7 +344,7 @@ cd LLaMA-Factory
 pip install -e ".[torch,metrics]"
 ```
 
-可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality
+可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、adam-mini、qwen、modelscope、quality
 
 > [!TIP]
 > 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。
diff --git a/examples/README.md b/examples/README.md
index d5aca5ad..f396ee40 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -189,6 +189,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 ```
 
+#### Full-Parameter Fine-Tuning using Adam-mini
+
+```bash
+llamafactory-cli train examples/extras/adam_mini/llama3_full_sft.yaml
+```
+
 #### LoRA+ Fine-Tuning
 
 ```bash
diff --git a/examples/README_zh.md b/examples/README_zh.md
index d96bf882..e179d77d 100644
--- a/examples/README_zh.md
+++ b/examples/README_zh.md
@@ -189,6 +189,12 @@ llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 ```
 
+#### 使用 Adam-mini 进行全参数训练
+
+```bash
+llamafactory-cli train examples/extras/adam_mini/llama3_full_sft.yaml
+```
+
 #### LoRA+ 微调
 
 ```bash
diff --git a/examples/extras/adam_mini/llama3_full_sft.yaml b/examples/extras/adam_mini/llama3_full_sft.yaml
new file mode 100644
index 00000000..c0c4740b
--- /dev/null
+++ b/examples/extras/adam_mini/llama3_full_sft.yaml
@@ -0,0 +1,39 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+use_adam_mini: true
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/llama3-8b/full/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-5
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
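A small, self-contained check of what the new recipe implies; it only reads the YAML above (plain bookkeeping, no training, assumes PyYAML is available):

```python
# Hedged sketch: derives batch bookkeeping from the Adam-mini recipe above.
# Assumes the file path from this patch and a single device; under DDP the
# samples-per-step figure scales with the number of processes.
import yaml

with open("examples/extras/adam_mini/llama3_full_sft.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

assert cfg["use_adam_mini"] is True and cfg["finetuning_type"] == "full"

# 1 sample/device x 8 accumulation steps = 8 samples per optimizer step
samples_per_step = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
# cutoff_len truncates each sample, so this is an upper bound on tokens/step
max_tokens_per_step = samples_per_step * cfg["cutoff_len"]
print(samples_per_step, max_tokens_per_step)  # 8 8192
```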
diff --git a/examples/extras/galore/llama3_full_sft.yaml b/examples/extras/galore/llama3_full_sft.yaml
index efd2c124..2911496d 100644
--- a/examples/extras/galore/llama3_full_sft.yaml
+++ b/examples/extras/galore/llama3_full_sft.yaml
@@ -34,6 +34,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 pure_bf16: true
+ddp_timeout: 180000000
 
 ### eval
 val_size: 0.1
diff --git a/examples/extras/llama_pro/expand.sh b/examples/extras/llama_pro/expand.sh
index e0d41c7b..9f3c013c 100644
--- a/examples/extras/llama_pro/expand.sh
+++ b/examples/extras/llama_pro/expand.sh
@@ -2,5 +2,5 @@
 
 python scripts/llama_pro.py \
     --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
-    --output_dir models/llama3-8b-instruct-pro \
+    --output_dir models/llama3-8b-pro \
     --num_expand 8
diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml
index 5e7e90bb..07f3e1ca 100644
--- a/examples/extras/llama_pro/llama3_freeze_sft.yaml
+++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml
@@ -1,5 +1,5 @@
 ### model
-model_name_or_path: models/llama3-8b-instruct-pro
+model_name_or_path: models/llama3-8b-pro
 
 ### method
 stage: sft
@@ -18,7 +18,7 @@ overwrite_cache: true
 preprocessing_num_workers: 16
 
 ### output
-output_dir: saves/llama3-8b-instruct-pro/freeze/sft
+output_dir: saves/llama3-8b-pro/freeze/sft
 logging_steps: 10
 save_steps: 500
 plot_loss: true
diff --git a/setup.py b/setup.py
index d43c311c..745d462f 100644
--- a/setup.py
+++ b/setup.py
@@ -47,6 +47,7 @@ extra_require = {
     "vllm": ["vllm>=0.4.3"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
+    "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "dev": ["ruff", "pytest"],
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 0edae2d4..ba1306e1 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -326,6 +326,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
+    use_adam_mini: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
+    )
     freeze_vision_tower: bool = field(
         default=True,
         metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
@@ -342,10 +346,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
-    use_adammini: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use AdamMini optimizer."},
-    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 72a18378..9907aa4f 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -128,6 +128,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_badam:
         require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
 
+    if finetuning_args.use_adam_mini:
+        require_version("adam-mini", "To fix: pip install adam-mini")
+
     if finetuning_args.plot_loss:
         require_version("matplotlib", "To fix: pip install matplotlib")
 
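The dataclass change above also renames `use_adammini` to `use_adam_mini` and moves it next to the other tuning flags, so the YAML key, the dependency gate in `parser.py`, and the attribute checked in `trainer_utils.py` now agree. A minimal sketch of how such a field is populated from a config dict, using a hypothetical trimmed-down dataclass rather than the real `FinetuningArguments`:

```python
# Hedged sketch with a toy stand-in dataclass; LLaMA-Factory's real
# FinetuningArguments carries many more fields but is parsed the same way
# via HfArgumentParser.
from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class ToyFinetuningArguments:
    use_adam_mini: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the Adam-mini optimizer."},
    )

parser = HfArgumentParser(ToyFinetuningArguments)
(finetuning_args,) = parser.parse_dict({"use_adam_mini": True})
print(finetuning_args.use_adam_mini)  # True
```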
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 6503cc42..2fa149a0 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from transformers import Trainer
 from transformers.integrations import is_deepspeed_zero3_enabled
+from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
@@ -365,18 +366,16 @@ def _create_badam_optimizer(
     return optimizer
 
-def _create_adammini_optimizer(
+
+def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
-    finetuning_args: "FinetuningArguments",
 ) -> "torch.optim.Optimizer":
     from adam_mini import Adam_mini
 
-    n_embd = model.config.hidden_size
-    n_head = model.config.num_attention_heads
-    n_query_groups = getattr(model.config, "num_key_value_heads", n_head)
-
-    print("n_embd", n_embd, "n_head", n_head, "n_query_groups", n_query_groups)
+    hidden_size = getattr(model.config, "hidden_size", None)
+    num_q_head = getattr(model.config, "num_attention_heads", None)
+    num_kv_head = getattr(model.config, "num_key_value_heads", None)
 
     optimizer = Adam_mini(
         named_parameters=model.named_parameters(),
@@ -384,14 +383,15 @@ def _create_adammini_optimizer(
         betas=(training_args.adam_beta1, training_args.adam_beta2),
         eps=training_args.adam_epsilon,
         weight_decay=training_args.weight_decay,
-        model_sharding=False,
-        dim=n_embd,
-        n_heads=n_head,
-        n_kv_heads=n_query_groups,
+        model_sharding=is_fsdp_enabled() or is_deepspeed_zero3_enabled(),
+        dim=hidden_size,
+        n_heads=num_q_head,
+        n_kv_heads=num_kv_head,
     )
-
+    logger.info("Using Adam-mini optimizer.")
     return optimizer
 
+
 def create_custom_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
@@ -406,9 +406,9 @@ def create_custom_optimizer(
     if finetuning_args.use_badam:
         return _create_badam_optimizer(model, training_args, finetuning_args)
 
-    if finetuning_args.use_adammini:
-        return _create_adammini_optimizer(model, training_args, finetuning_args)
-
+    if finetuning_args.use_adam_mini:
+        return _create_adam_mini_optimizer(model, training_args)
+
 
 def create_custom_scheduler(
     training_args: "Seq2SeqTrainingArguments",
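For reference, this is what `_create_adam_mini_optimizer` builds, pulled out of the Trainer. The constructor arguments follow the hunk above; the `model_sharding` change is the functional fix, since Adam-mini partitions its state per parameter block and must know when FSDP or DeepSpeed ZeRO-3 has already sharded the model. The model choice and hyperparameter values below are illustrative assumptions, not project defaults:

```python
# Hedged sketch of the wiring above, outside the Trainer. Qwen/Qwen2-0.5B
# stands in for any causal LM exposing the three config fields Adam-mini
# needs; swap in your own checkpoint. lr/betas/eps/weight_decay are
# illustrative values only.
from adam_mini import Adam_mini
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")

optimizer = Adam_mini(
    named_parameters=model.named_parameters(),
    lr=1.0e-5,
    betas=(0.9, 0.999),
    eps=1.0e-8,
    weight_decay=0.0,
    model_sharding=False,  # set True under FSDP / DeepSpeed ZeRO-3, as the patch does
    dim=getattr(model.config, "hidden_size", None),
    n_heads=getattr(model.config, "num_attention_heads", None),
    n_kv_heads=getattr(model.config, "num_key_value_heads", None),
)
```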