fix ci
parent aa2578bea0
commit 7f20e4722a
@@ -30,10 +30,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
-          make style && make quality
+          make style && make quality

   pytest:
     needs: check_code_quality
@@ -53,7 +53,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Test with pytest
         run: |
           make test
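Note: both workflow jobs above now install the project with python -m pip install .[torch,dev]; the quality job relies on ruff and the test job presumably on pytest, both of which come from the new dev extra defined in setup.py below. A quick local sanity check, as a sketch only (this snippet is illustrative and not part of the commit; it assumes the repository was installed with the dev extra):

    from importlib.metadata import version

    # After installing with the dev extra, both tools the workflow calls should resolve.
    for tool in ("ruff", "pytest"):
        print(tool, version(tool))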
setup.py
@@ -33,7 +33,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "quality": ["ruff"],
+    "dev": ["ruff", "pytest"],
 }
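For context, the extra_require mapping edited above is what lets the workflow's pip install .[torch,dev] resolve: pip looks each bracketed name up in the extras passed to setup(). A minimal sketch of that wiring, assuming the dict is handed to setuptools as extras_require (the setup() call below is illustrative; the package name, src layout, and torch pin are assumptions, and only the dev entry comes from this commit):

    from setuptools import find_packages, setup

    extra_require = {
        "torch": ["torch>=1.13.1"],   # illustrative pin, not taken from this diff
        "dev": ["ruff", "pytest"],    # the extra added in this commit
    }

    setup(
        name="llmtuner",                  # assumed package name
        package_dir={"": "src"},          # assumed src layout
        packages=find_packages("src"),
        extras_require=extra_require,     # makes "pip install .[torch,dev]" work
    )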
@@ -23,13 +23,15 @@ def test_attention():
         "fa2": "LlamaFlashAttention2",
     }
     for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({
-            "model_name_or_path": TINY_LLAMA,
-            "template": "llama2",
-            "flash_attn": requested_attention,
-        })
+        model_args, _, finetuning_args, _ = get_infer_args(
+            {
+                "model_name_or_path": TINY_LLAMA,
+                "template": "llama2",
+                "flash_attn": requested_attention,
+            }
+        )
         tokenizer = load_tokenizer(model_args)
         model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
         for module in model.modules():
             if "Attention" in module.__class__.__name__:
-                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
+                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
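The test-file hunk above is formatting only: the get_infer_args call is rewrapped across multiple lines and the assertion is unchanged, so test behavior stays the same. With pytest now available from the dev extra, the reformatted test can be run on its own; a hedged example (pytest.main is standard pytest API, and the keyword filter simply selects test_attention):

    import pytest

    # Run only the attention test, quietly; equivalent to `pytest -k test_attention -q`.
    pytest.main(["-k", "test_attention", "-q"])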