fix flashattn warning

hiyouga 2023-11-10 18:34:54 +08:00
parent a0c31c68c4
commit 4bd8e3906d
2 changed files with 10 additions and 4 deletions


@@ -5,11 +5,14 @@ from typing import Optional, Tuple
 from transformers.utils import logging
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
+is_flash_attn_2_available = False
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore
+    is_flash_attn_2_available = True
 except ImportError:
-    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
+    is_flash_attn_2_available = False
 
 logger = logging.get_logger(__name__)


@@ -123,9 +123,12 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if LlamaPatches.is_flash_attn_2_available:
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
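Taken together, the two hunks implement a probe-then-patch pattern: the patch module records whether flash_attn imported successfully, and load_model_and_tokenizer swaps in LlamaFlashAttention2 only when it did, warning otherwise instead of patching in a backend that is missing. A minimal self-contained sketch of that pattern, using hypothetical stand-in names (optional_dep, FastAttention, apply_patch) rather than the repository's real modules:

import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)

# Probe: try the optional dependency once and remember the outcome.
try:
    import optional_dep  # hypothetical optional accelerator package
    is_available = True
except ImportError:
    is_available = False


def apply_patch(target_module) -> None:
    # Patch: swap in the accelerated class only if the probe succeeded,
    # otherwise emit a warning rather than failing or staying silent.
    if is_available:
        target_module.Attention = optional_dep.FastAttention
        logger.info("Using the accelerated attention implementation.")
    else:
        logger.warning("Optional dependency is not installed; keeping the default implementation.")


# With optional_dep absent, this logs the warning and leaves the default class in place.
apply_patch(SimpleNamespace(Attention=object))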