diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 7dee827c..da53baa2 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -36,13 +36,14 @@ def configure_attn_implementation(
         if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
+                require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
                 logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                 model_args.flash_attn = "fa2"
             else:
                 logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
                 model_args.flash_attn = "disabled"
         elif model_args.flash_attn == "sdpa":
-            raise ValueError("Gemma-2 should use soft-capping attention, while the SDPA attention is not compatible.")
+            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
 
     if model_args.flash_attn == "auto":
         return
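
Review note: a minimal standalone sketch of the branch logic after this patch, to make the behavior change easy to see. SimpleModelArgs, fa2_available, and resolve_gemma2_attn are hypothetical stand-ins for the real ModelArguments dataclass, transformers' is_flash_attn_2_available(), and the patched function; the surrounding guard that restricts this block to Gemma-2 models is not shown in the hunk and is omitted here.

# Sketch only: mirrors the patched branch logic, not the actual implementation.
from dataclasses import dataclass


@dataclass
class SimpleModelArgs:
    flash_attn: str = "auto"  # one of: "auto", "fa2", "sdpa", "disabled"


def resolve_gemma2_attn(model_args: SimpleModelArgs, fa2_available: bool) -> SimpleModelArgs:
    if model_args.flash_attn in ("auto", "fa2"):
        if fa2_available:
            # The patch adds a flash_attn>=2.6.0 version check here before forcing fa2.
            model_args.flash_attn = "fa2"
        else:
            model_args.flash_attn = "disabled"
    elif model_args.flash_attn == "sdpa":
        # The patch downgrades the former ValueError to a warning; the user's setting is kept.
        pass
    return model_args


assert resolve_gemma2_attn(SimpleModelArgs("auto"), fa2_available=True).flash_attn == "fa2"
assert resolve_gemma2_attn(SimpleModelArgs("auto"), fa2_available=False).flash_attn == "disabled"
assert resolve_gemma2_attn(SimpleModelArgs("sdpa"), fa2_available=True).flash_attn == "sdpa"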