Merge pull request #4446 from stceum/bug-fix

Bug Fix: `off` is parsed as `False` in yaml file
Authored by hoshi-hiyouga on 2024-06-24 21:41:28 +08:00, committed by GitHub
commit cc452c32c7
3 changed files with 4 additions and 4 deletions
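
The root cause is YAML 1.1 boolean resolution: PyYAML's safe loader coerces the bare scalars `off`/`on`/`yes`/`no` into booleans, so a config line `flash_attn: off` reaches the argument parser as `False` rather than the string `"off"`. Renaming the option to `disabled` sidesteps the coercion (quoting the value as `"off"` in the YAML would also work, but the rename removes the footgun for hand-written configs). A minimal reproduction:

```python
import yaml

# PyYAML follows YAML 1.1 boolean resolution: a bare `off` becomes False.
print(yaml.safe_load("flash_attn: off"))       # {'flash_attn': False}

# The renamed option is an ordinary string and survives loading unchanged.
print(yaml.safe_load("flash_attn: disabled"))  # {'flash_attn': 'disabled'}
```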


@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )


@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments"
     if model_args.flash_attn == "auto":
         return
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
     elif model_args.flash_attn == "sdpa":
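
For context, the hunk above sits in the dispatch that turns the user-facing `flash_attn` value into an attention implementation name that `transformers` understands. Below is a rough sketch of that mapping using the standard `transformers` identifiers (`eager`, `sdpa`, `flash_attention_2`); `resolve_attn_implementation` is a hypothetical helper for illustration, not the repository's function, which does additional work (such as availability checks) that this sketch omits:

```python
def resolve_attn_implementation(flash_attn: str):
    """Map the user-facing option to an `attn_implementation` name (illustrative only)."""
    if flash_attn == "auto":
        return None  # defer to transformers' own default selection
    if flash_attn == "disabled":
        return "eager"  # plain attention, no fused kernels
    if flash_attn == "sdpa":
        return "sdpa"  # torch.nn.functional.scaled_dot_product_attention
    if flash_attn == "fa2":
        return "flash_attention_2"  # FlashAttention-2 kernels
    raise ValueError(f"Unknown flash_attn value: {flash_attn!r}")
```

In recent `transformers` releases these are the accepted values for the `attn_implementation` argument of `from_pretrained`.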


@@ -29,7 +29,7 @@ INFER_ARGS = {
 def test_attention():
-    attention_available = ["off"]
+    attention_available = ["disabled"]
     if is_torch_sdpa_available():
         attention_available.append("sdpa")
@@ -37,7 +37,7 @@ def test_attention():
         attention_available.append("fa2")
     llama_attention_classes = {
-        "off": "LlamaAttention",
+        "disabled": "LlamaAttention",
         "sdpa": "LlamaSdpaAttention",
         "fa2": "LlamaFlashAttention2",
     }
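
The test changes mirror the rename: both the list of available backends and the expected Llama attention class per backend switch their key from `off` to `disabled` (`LlamaAttention`, `LlamaSdpaAttention`, and `LlamaFlashAttention2` are the classes `transformers` instantiates for the eager, SDPA, and FlashAttention-2 implementations). A rough sketch of the kind of check such a test performs, assuming `model` is an already-loaded Llama model and `requested` is the backend under test; `assert_attention_class` is a hypothetical helper, not the repository's test code:

```python
def assert_attention_class(model, requested: str) -> None:
    # Expected transformers attention class per requested implementation.
    expected = {
        "disabled": "LlamaAttention",       # eager attention
        "sdpa": "LlamaSdpaAttention",       # scaled_dot_product_attention
        "fa2": "LlamaFlashAttention2",      # FlashAttention-2
    }[requested]
    for module in model.modules():
        if "Attention" in type(module).__name__:
            assert type(module).__name__ == expected
```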