commit 7f20e4722a
parent aa2578bea0
Author: hiyouga
Date:   2024-06-08 01:57:36 +08:00

3 changed files with 12 additions and 10 deletions

@@ -30,10 +30,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
           make style && make quality

   pytest:
     needs: check_code_quality
@@ -53,7 +53,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Test with pytest
         run: |
           make test
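Both CI jobs now install the same consolidated .[torch,dev] extra in place of the old [torch,metrics,quality] set. One way to confirm what a renamed extra actually pulls in after installation is to read the distribution metadata; a minimal sketch, assuming the distribution is installed under the name "llamafactory" (an assumption, not taken from this diff):

    # Hedged sketch: list the requirements gated behind the "dev" extra.
    # "llamafactory" is an assumed distribution name; adjust to the real one.
    from importlib.metadata import requires

    for req in requires("llamafactory") or []:
        if "extra ==" in req and "dev" in req:
            print(req)  # expect entries for ruff and pytest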

@@ -33,7 +33,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "quality": ["ruff"],
+    "dev": ["ruff", "pytest"],
 }
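For context, a table like the one edited above is what setuptools consumes through the extras_require argument of setup(); a minimal sketch, where the package name and version are illustrative assumptions:

    from setuptools import setup

    setup(
        name="llamafactory",  # assumption: illustrative, not read from this diff
        version="0.0.0",      # assumption
        extras_require={
            # the group this commit renames from "quality" and extends with pytest
            "dev": ["ruff", "pytest"],
        },
    )

pip then resolves the group with a command of the form pip install ".[torch,dev]", which is what the workflow change above switches to.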

@@ -23,13 +23,15 @@ def test_attention():
         "fa2": "LlamaFlashAttention2",
     }
     for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({
-            "model_name_or_path": TINY_LLAMA,
-            "template": "llama2",
-            "flash_attn": requested_attention,
-        })
+        model_args, _, finetuning_args, _ = get_infer_args(
+            {
+                "model_name_or_path": TINY_LLAMA,
+                "template": "llama2",
+                "flash_attn": requested_attention,
+            }
+        )
         tokenizer = load_tokenizer(model_args)
         model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
         for module in model.modules():
             if "Attention" in module.__class__.__name__:
                 assert module.__class__.__name__ == llama_attention_classes[requested_attention]
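The reformatted get_infer_args call changes layout only; the assertions are untouched. The pattern itself, walking model.modules() and checking each attention submodule's class name, can be reproduced against transformers directly; a minimal sketch, where the checkpoint name is an assumption and the project helpers (get_infer_args, load_tokenizer, load_model, TINY_LLAMA) are not reproduced:

    # Hedged sketch of the assertion pattern in test_attention, written
    # against transformers directly rather than the project's loaders.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/tiny-random-LlamaForCausalLM",  # assumption: any tiny Llama checkpoint works
        attn_implementation="eager",  # or "sdpa" / "flash_attention_2"
    )
    for module in model.modules():
        if "Attention" in module.__class__.__name__:
            # "eager" -> LlamaAttention, "sdpa" -> LlamaSdpaAttention,
            # "fa2"/"flash_attention_2" -> LlamaFlashAttention2
            assert module.__class__.__name__ == "LlamaAttention"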