From 4bd8e3906d09bf6ec4b8f6b553a347fca9db4f80 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 10 Nov 2023 18:34:54 +0800
Subject: [PATCH] fix flashattn warning

---
 src/llmtuner/extras/patches/llama_patch.py | 5 ++++-
 src/llmtuner/tuner/core/loader.py          | 9 ++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index a8473311..bf3e5d57 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -5,11 +5,14 @@ from typing import Optional, Tuple
 from transformers.utils import logging
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
+is_flash_attn_2_available = False
+
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
+    is_flash_attn_2_available = True
 except ImportError:
-    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
+    is_flash_attn_2_available = False
 
 
 logger = logging.get_logger(__name__)
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 663f60d9..34bc2a6e 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -123,9 +123,12 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if LlamaPatches.is_flash_attn_2_available:
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else: