Bug Fix: `off` is parsed as the boolean `False` in YAML files, so the option value was renamed to `disabled` to avoid this.

This commit is contained in:
stceum 2024-06-24 20:39:20 +08:00
parent 5b897e7c35
commit 3ed063f281
3 changed files with 6 additions and 2 deletions

View File

@ -97,7 +97,7 @@ class ModelArguments:
default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
)
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
default="auto",
metadata={"help": "Enable FlashAttention for faster training and inference."},
)

View File

@ -102,6 +102,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
# In case `flash_attn` is set to `off` in the YAML file and is parsed as `False` afterwards.
if model_args.flash_attn == False:
raise ValueError("flash_attn should be \"disabled\", \"sdpa\", \"fa2\" or \"auto\".")
def _check_extra_dependencies(
model_args: "ModelArguments",

View File

@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
if model_args.flash_attn == "auto":
return
elif model_args.flash_attn == "off":
elif model_args.flash_attn == "disabled":
requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa":