diff --git a/setup.py b/setup.py index 594070cd..dd333c48 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,8 @@ def get_requires(): extra_require = { "torch": ["torch>=1.13.1"], - "torch-npu-arm64": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], - "torch-npu-amd64": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"], + "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], + "torch-npu-amd": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], "deepspeed": ["deepspeed>=0.10.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],