hiyouga 2024-07-14 10:56:45 +08:00
parent 2f6af73da2
commit d3c01552e0
1 changed file with 2 additions and 1 deletion


@@ -36,13 +36,14 @@ def configure_attn_implementation(
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
                 logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                 model_args.flash_attn = "fa2"
             else:
                 logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
                 model_args.flash_attn = "disabled"
         elif model_args.flash_attn == "sdpa":
-            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
+            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
 
     if model_args.flash_attn == "auto":
         return
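
For context, below is a minimal, self-contained sketch of the patched check, assuming a simplified ModelArguments dataclass with a flash_attn field (the function and dataclass names here are hypothetical stand-ins for illustration; require_version and is_flash_attn_2_available are real transformers utilities, as used in the diff above).

import logging
from dataclasses import dataclass

from transformers.utils import is_flash_attn_2_available
from transformers.utils.versions import require_version

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    # one of: "auto", "fa2", "sdpa", "disabled"
    flash_attn: str = "auto"


def configure_gemma2_attention(model_args: ModelArguments) -> None:
    if model_args.flash_attn in ("auto", "fa2"):
        if is_flash_attn_2_available():
            # Gemma-2's soft-capping attention needs recent transformers and
            # flash_attn; require_version raises with the hint if unmet.
            require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
            require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
            logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
            model_args.flash_attn = "fa2"
        else:
            logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
            model_args.flash_attn = "disabled"
    elif model_args.flash_attn == "sdpa":
        # Patched behavior: warn instead of raising, since SDPA lacks soft-capping.
        logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")

The design change in this commit is the last branch: an unsupported SDPA setting used to abort with a ValueError, but after the patch it only logs a warning, letting the run continue with the user's chosen attention backend.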