diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 80d9d4b8..4bed7e21 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -35,7 +35,7 @@ def configure_attn_implementation(
         if model_args.flash_attn == "auto":
             logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
             model_args.flash_attn = "disabled"
-        else:
+        elif model_args.flash_attn != "disabled":
            logger.warning(
                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
                "Will proceed at your own risk.".format(model_args.flash_attn)
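
For context, a minimal sketch of the control flow after this change, assuming a simplified stand-in for LLaMA-Factory's ModelArguments and logger (the real configure_attn_implementation also receives the model config and a trainability flag). With the new elif, the "proceed at your own risk" warning fires only when the user explicitly requests a non-eager backend, and no longer when `flash_attn` is already "disabled":

```python
import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:  # hypothetical stand-in for the real ModelArguments
    flash_attn: str = "auto"  # e.g. "auto", "disabled", "sdpa", "fa2"


def configure_gemma2_attention(model_args: ModelArguments) -> None:
    """Sketch of the Gemma-2 branch of configure_attn_implementation after the patch."""
    if model_args.flash_attn == "auto":
        # Auto mode: silently fall back to eager attention.
        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
        model_args.flash_attn = "disabled"
    elif model_args.flash_attn != "disabled":
        # Warn only when a non-eager backend was explicitly requested;
        # before the patch this branch also fired when flash_attn was already "disabled".
        logger.warning(
            "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
            "Will proceed at your own risk.".format(model_args.flash_attn)
        )


configure_gemma2_attention(ModelArguments(flash_attn="disabled"))  # no warning after the patch
configure_gemma2_attention(ModelArguments(flash_attn="fa2"))       # warns: proceed at your own risk
```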