Update test_attention.py
parent 3ed063f281
commit a9b3d91952
@@ -29,7 +29,7 @@ INFER_ARGS = {
 
 
 def test_attention():
-    attention_available = ["off"]
+    attention_available = ["disabled"]
     if is_torch_sdpa_available():
         attention_available.append("sdpa")
@@ -37,7 +37,7 @@ def test_attention():
         attention_available.append("fa2")
 
     llama_attention_classes = {
-        "off": "LlamaAttention",
+        "disabled": "LlamaAttention",
         "sdpa": "LlamaSdpaAttention",
         "fa2": "LlamaFlashAttention2",
     }
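For context, the diff renames the "off" attention option to "disabled" throughout the test while keeping the mapping from each option to the expected LLaMA attention class. Below is a minimal, hedged sketch of how such a test could look end to end; it is not the repository's actual code. It assumes transformers still exposes is_torch_sdpa_available / is_flash_attn_2_available and the per-implementation Llama*Attention classes (as in transformers 4.36 through 4.47), and the tiny model id plus the attn_implementation mapping are placeholders for illustration.

# Hedged sketch reconstructing the pattern the diff implies; model id and the
# attn_implementation mapping are assumptions, not taken from the repository.
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

TINY_LLAMA = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny test model


def test_attention():
    attention_available = ["disabled"]
    if is_torch_sdpa_available():
        attention_available.append("sdpa")

    if is_flash_attn_2_available():
        attention_available.append("fa2")

    llama_attention_classes = {
        "disabled": "LlamaAttention",
        "sdpa": "LlamaSdpaAttention",
        "fa2": "LlamaFlashAttention2",
    }
    # Assumed translation from the test's option names to the values accepted by
    # transformers' attn_implementation argument ("disabled" falls back to eager).
    attn_implementation = {"disabled": "eager", "sdpa": "sdpa", "fa2": "flash_attention_2"}

    for requested in attention_available:
        model = AutoModelForCausalLM.from_pretrained(
            TINY_LLAMA,
            attn_implementation=attn_implementation[requested],
            torch_dtype=torch.float16,
        )
        # Every attention submodule should be the class expected for this option.
        for module in model.modules():
            if "Attention" in type(module).__name__:
                assert type(module).__name__ == llama_attention_classes[requested]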