From 07737a3d2d026c973ab964f948953d6ce0e1f2a9 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 24 Apr 2024 02:18:44 +0800
Subject: [PATCH] reenable sdpa and fast tok by default

---
 README.md                            |  4 +--
 README_zh.md                         |  4 +--
 requirements.txt                     |  1 +
 src/llmtuner/extras/packages.py      | 19 ++++++++---
 src/llmtuner/hparams/model_args.py   |  8 ++---
 src/llmtuner/model/patcher.py        | 49 ++++++++++++++++++++++------
 src/llmtuner/webui/components/top.py |  2 +-
 src/llmtuner/webui/runner.py         |  4 +--
 8 files changed, 64 insertions(+), 27 deletions(-)

diff --git a/README.md b/README.md
index 0bf9f731..970dd8fc 100644
--- a/README.md
+++ b/README.md
@@ -72,8 +72,6 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [24/04/21] We supported **[Mixture-of-Depths](https://arxiv.org/abs/2404.02258)** according to [AstraMindAI's implementation](https://github.com/astramind-ai/Mixture-of-depths). See `examples/extras/mod` for usage.
 
-[24/04/19] We supported **Meta Llama 3** model series.
-
 [24/04/16] We supported **[BAdam](https://arxiv.org/abs/2404.02827)**. See `examples/extras/badam` for usage.
 
 [24/04/16] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s long-sequence training (Llama-2-7B-56k within 24GB). It achieves **117%** speed and **50%** memory compared with FlashAttention-2, more benchmarks can be found in [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
@@ -112,7 +110,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
-[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn fa2` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
 
diff --git a/README_zh.md b/README_zh.md
index 69ba2562..583c89ca 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -72,8 +72,6 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 
 [24/04/21] 我们基于 [AstraMindAI 的仓库](https://github.com/astramind-ai/Mixture-of-depths)支持了 **[混合深度训练](https://arxiv.org/abs/2404.02258)**。详细用法请参照 `examples/extras/mod`。
 
-[24/04/19] 我们支持了 **Meta Llama 3** 系列模型。
-
 [24/04/16] 我们支持了 **[BAdam](https://arxiv.org/abs/2404.02827)**。详细用法请参照 `examples/extras/badam`。
 
 [24/04/16] 我们支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的长序列训练(24GB 可训练 Llama-2-7B-56k)。该方法相比 FlashAttention-2 提供了 **117%** 的训练速度和 **50%** 的显存节约。更多数据请见[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
@@ -112,7 +110,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 
 [23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
 
-[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
+[23/09/10] 我们支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn fa2` 参数以启用 FlashAttention-2。
 
 [23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
 
diff --git a/requirements.txt b/requirements.txt
index 3928d28d..ecba3ce1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,3 +15,4 @@ fastapi
 sse-starlette
 matplotlib
 fire
+packaging
diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py
index 8494cb2c..aeeba084 100644
--- a/src/llmtuner/extras/packages.py
+++ b/src/llmtuner/extras/packages.py
@@ -1,16 +1,23 @@
 import importlib.metadata
 import importlib.util
+from typing import TYPE_CHECKING
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+    from packaging.version import Version
 
 
 def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None
 
 
-def _get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> "Version":
     try:
-        return importlib.metadata.version(name)
+        return version.parse(importlib.metadata.version(name))
     except Exception:
-        return "0.0.0"
+        return version.parse("0.0.0")
 
 
 def is_fastapi_availble():
@@ -18,7 +25,7 @@ def is_fastapi_availble():
 
 
 def is_flash_attn2_available():
-    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
 
 
 def is_galore_available():
@@ -49,6 +56,10 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
+def is_sdpa_available():
+    return _get_package_version("torch") > version.parse("2.1.1")
+
+
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index 0e42033f..eb6366d9 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -22,7 +22,7 @@ class ModelArguments:
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
     use_fast_tokenizer: bool = field(
-        default=False,
+        default=True,
         metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
    )
     resize_vocab: bool = field(
@@ -61,9 +61,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: bool = field(
-        default=False,
-        metadata={"help": "Enable FlashAttention for faster training."},
+    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+        default="auto",
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
         default=False,
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 53616dd9..6c79992a 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -15,7 +15,7 @@ from transformers.utils.versions import require_version
 from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
 from ..extras.logging import get_logger
 from ..extras.misc import get_current_device, infer_optim_dtype
-from ..extras.packages import is_flash_attn2_available
+from ..extras.packages import is_flash_attn2_available, is_sdpa_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable
 
@@ -62,18 +62,45 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
 
 
 def _configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
-    if model_args.flash_attn:
-        if not is_flash_attn2_available():
-            logger.warning("FlashAttention2 is not installed.")
+    if model_args.flash_attn == "auto":
+        return
+
+    elif model_args.flash_attn == "off":
+        requested_attn_implementation = "eager"
+
+    elif model_args.flash_attn == "sdpa":
+        if not is_sdpa_available():
+            logger.warning("Torch>=2.1.1 is required for SDPA attention.")
             return
 
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-        if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
-            setattr(config, "attn_implementation", "flash_attention_2")
-        else:
-            setattr(config, "_attn_implementation", "flash_attention_2")
+        requested_attn_implementation = "sdpa"
+    elif model_args.flash_attn == "fa2":
+        if not is_flash_attn2_available():
+            logger.warning("FlashAttention-2 is not installed.")
+            return
+
+        requested_attn_implementation = "flash_attention_2"
     else:
-        setattr(config, "_attn_implementation", "eager")
+        raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn))
+
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        setattr(config, "attn_implementation", requested_attn_implementation)
+    else:
+        setattr(config, "_attn_implementation", requested_attn_implementation)
+
+
+def _print_attn_implementation(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
+        attn_implementation = getattr(config, "attn_implementation", None)
+    else:
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+    if attn_implementation == "flash_attention_2":
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+    elif attn_implementation == "sdpa":
+        logger.info("Using torch SDPA for faster training and inference.")
+    else:
+        logger.info("Using vanilla Attention implementation.")
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -365,6 +392,8 @@ def patch_model(
 
         add_z3_leaf_module(model, Qwen2MoeSparseMoeBlock)
 
+    _print_attn_implementation(model.config)
+
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 6cbf6e0d..c67d7cc5 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -33,7 +33,7 @@ def create_top() -> Dict[str, "Component"]:
             quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
             template = gr.Dropdown(choices=list(templates.keys()), value="default")
             rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-            booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")
+            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
 
     model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index ec493c96..b64a015c 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -67,7 +67,7 @@ class Runner:
         if not model_path:
             return ALERTS["err_no_path"][lang]
 
-        if len(dataset) == 0:
+        if not dataset:
             return ALERTS["err_no_dataset"][lang]
 
         if not from_preview and self.demo_mode:
@@ -122,7 +122,7 @@ class Runner:
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
-            flash_attn=(get("top.booster") == "flashattn"),
+            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
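
Note (not part of the patch): a minimal sketch of what the new `flash_attn` choices amount to on the transformers side. `_configure_attn_implementation` writes "eager", "sdpa", or "flash_attention_2" into the model config, which is equivalent to passing `attn_implementation` to `from_pretrained`, while "auto" leaves the choice to transformers. The model id and dtype below are placeholders, not values taken from this patch.

# Illustrative only: the mapping applied by _configure_attn_implementation
# is off -> "eager", sdpa -> "sdpa", fa2 -> "flash_attention_2".
import torch
from transformers import AutoModelForCausalLM

FLASH_ATTN_TO_IMPL = {"off": "eager", "sdpa": "sdpa", "fa2": "flash_attention_2"}

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",          # placeholder model id
    torch_dtype=torch.bfloat16,          # placeholder dtype
    attn_implementation=FLASH_ATTN_TO_IMPL["sdpa"],
)
print(model.config._attn_implementation)  # expected to print "sdpa"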