From 72bc8f01111ad69b92a647b54b4af988515d9c34 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 27 Aug 2024 11:20:14 +0800 Subject: [PATCH] support liger kernel --- README.md | 8 ++-- README_zh.md | 8 ++-- src/llamafactory/hparams/model_args.py | 4 ++ src/llamafactory/hparams/parser.py | 3 ++ .../model/model_utils/liger_kernel.py | 48 +++++++++++++++++++ src/llamafactory/model/patcher.py | 2 + src/llamafactory/webui/components/top.py | 2 +- src/llamafactory/webui/runner.py | 1 + 8 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 src/llamafactory/model/model_utils/liger_kernel.py diff --git a/README.md b/README.md index d034e574..25aa7881 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Choose your path: - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Advanced algorithms**: GaLore, BAdam, Adam-mini, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. -- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA. +- **Practical tricks**: FlashAttention-2, Unsloth, Liger Kernel, RoPE scaling, NEFTune and rsLoRA. - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc. - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker. @@ -72,14 +72,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `use_liger_kernel: true` for efficient training. + [24/08/09] We support **[Adam-mini](https://arxiv.org/abs/2406.16793)** optimizer. See [examples](examples/README.md) for usage. Thank [@relic-yuexi](https://github.com/relic-yuexi)'s PR. [24/07/04] We support [contamination-free packed training](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing). Use `neat_packing: true` to activate it. Thank [@chuan298](https://github.com/chuan298)'s PR. -[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. -
Full Changelog +[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage. + [24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models. [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage. diff --git a/README_zh.md b/README_zh.md index 6b8335ca..37bc2f5f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -52,7 +52,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **先进算法**:GaLore、BAdam、Adam-mini、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 -- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。 +- **实用技巧**:FlashAttention-2、Unsloth、Liger Kernel、RoPE scaling、NEFTune 和 rsLoRA。 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。 - **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。 @@ -73,14 +73,16 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[24/08/27] 我们支持了 **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**。请使用 `use_liger_kernel: true` 来加速训练。 + [24/08/09] 我们支持了 **[Adam-mini](https://arxiv.org/abs/2406.16793)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。感谢 [@relic-yuexi](https://github.com/relic-yuexi) 的 PR。 [24/07/04] 我们支持了[无污染打包训练](https://github.com/MeetKai/functionary/tree/main/functionary/train/packing)。请使用 `neat_packing: true` 参数。感谢 [@chuan298](https://github.com/chuan298) 的 PR。 -[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 -
展开日志 +[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。 + [24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。 diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 4ac47512..f209e338 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -117,6 +117,10 @@ class ModelArguments: default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}, ) + use_liger_kernel: bool = field( + default=False, + metadata={"help": "Whether or not to enable liger kernel for faster training."}, + ) visual_inputs: bool = field( default=False, metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 9907aa4f..bea3d650 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -116,6 +116,9 @@ def _check_extra_dependencies( if model_args.use_unsloth: require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") + if model_args.use_liger_kernel: + require_version("liger-kernel", "To fix: pip install liger-kernel") + if model_args.mixture_of_depths is not None: require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py new file mode 100644 index 00000000..61de0be0 --- /dev/null +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -0,0 +1,48 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.use_liger_kernel: + return + + if getattr(config, "model_type", None) == "gemma": + from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel + elif getattr(config, "model_type", None) == "llama": + from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel + elif getattr(config, "model_type", None) == "mistral": + from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel + elif getattr(config, "model_type", None) == "mixtral": + from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel + elif getattr(config, "model_type", None) == "qwen2": + from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel + else: + logger.warning("Current model does not support liger kernel.") + return + + apply_liger_kernel() + logger.info("Liger kernel has been applied to the model.") diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index a99d38e0..a278c154 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -27,6 +27,7 @@ from ..extras.misc import infer_optim_dtype from .model_utils.attention import configure_attn_implementation, print_attn_implementation from .model_utils.checkpointing import prepare_model_for_training from .model_utils.embedding import resize_embedding_layer +from .model_utils.liger_kernel import configure_liger_kernel from .model_utils.longlora import configure_longlora from .model_utils.moe import add_z3_leaf_module, configure_moe from .model_utils.packing import configure_packing @@ -70,6 +71,7 @@ def patch_config( configure_attn_implementation(config, model_args, is_trainable) configure_rope(config, model_args, is_trainable) + configure_liger_kernel(config, model_args, is_trainable) configure_longlora(config, model_args, is_trainable) configure_quantization(config, tokenizer, model_args, init_kwargs) configure_moe(config, model_args, is_trainable) diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py index ac1a7a42..f2630c7b 100644 --- a/src/llamafactory/webui/components/top.py +++ b/src/llamafactory/webui/components/top.py @@ -47,7 +47,7 @@ def create_top() -> Dict[str, "Component"]: quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=1) template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=1) rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=2) - booster = gr.Radio(choices=["auto", "flashattn2", "unsloth"], value="auto", scale=2) + booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=3) visual_inputs = gr.Checkbox(scale=1) model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then( diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 4ce35a02..72176986 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -115,6 +115,7 @@ class Runner: rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto", use_unsloth=(get("top.booster") == "unsloth"), + use_liger_kernel=(get("top.booster") == "liger_kernel"), visual_inputs=get("top.visual_inputs"), dataset_dir=get("train.dataset_dir"), dataset=",".join(get("train.dataset")),