From bdda0827b35cbb6005f10aa661fe6f3940b598d8 Mon Sep 17 00:00:00 2001 From: fanjunliang Date: Thu, 27 Jun 2024 15:21:55 +0800 Subject: [PATCH] support docker-npu-[amd64|arm64] build --- README.md | 2 +- README_zh.md | 2 +- docker/docker-npu/Dockerfile | 12 +++++++++--- setup.py | 3 ++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 4b42edd7..9c509ff0 100644 --- a/README.md +++ b/README.md @@ -465,7 +465,7 @@ For Ascend NPU users: ```bash # Choose docker image upon your environment -docker build -f ./docker/docker-npu/Dockerfile \ +docker build --platform linux/arm64 -f ./docker/docker-npu/Dockerfile \ --build-arg INSTALL_DEEPSPEED=false \ --build-arg PIP_INDEX=https://pypi.org/simple \ -t llamafactory:latest . diff --git a/README_zh.md b/README_zh.md index 3926c09d..c3fb6ecf 100644 --- a/README_zh.md +++ b/README_zh.md @@ -465,7 +465,7 @@ docker exec -it llamafactory bash ```bash # 根据您的环境选择镜像 -docker build -f ./docker/docker-npu/Dockerfile \ +docker build --platform linux/arm64 -f ./docker/docker-npu/Dockerfile \ --build-arg INSTALL_DEEPSPEED=false \ --build-arg PIP_INDEX=https://pypi.org/simple \ -t llamafactory:latest . diff --git a/docker/docker-npu/Dockerfile b/docker/docker-npu/Dockerfile index 0ec16107..8d80397e 100644 --- a/docker/docker-npu/Dockerfile +++ b/docker/docker-npu/Dockerfile @@ -1,10 +1,11 @@ # Use the Ubuntu 22.04 image with CANN 8.0.rc1 # More versions can be found at https://hub.docker.com/r/cosdt/cann/tags -FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04 +FROM --platform=$TARGETPLATFORM cosdt/cann:8.0.rc1-910b-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive # Define installation arguments +ARG TARGETPLATFORM ARG INSTALL_DEEPSPEED=false ARG PIP_INDEX=https://pypi.org/simple ARG EXTRA_INDEX=https://download.pytorch.org/whl/cpu @@ -15,7 +16,6 @@ WORKDIR /app # Install the requirements COPY requirements.txt /app RUN pip config set global.index-url $PIP_INDEX && \ - pip config set global.extra-index-url $EXTRA_INDEX && \ pip install --upgrade pip && \ pip install -r requirements.txt @@ -23,7 +23,13 @@ RUN pip config set global.index-url $PIP_INDEX && \ COPY . /app # Install the LLaMA Factory -RUN EXTRA_PACKAGES="torch-npu,metrics"; \ +RUN EXTRA_PACKAGES="metrics"; \ + if [ "$TARGETPLATFORM" == "linux/arm64" ]; then \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},torch-npu-arm64"; \ + else \ + pip config set global.extra-index-url $EXTRA_INDEX; \ + EXTRA_PACKAGES="${EXTRA_PACKAGES},torch-npu-amd64"; \ + fi; \ if [ "$INSTALL_DEEPSPEED" = "true" ]; then \ EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \ fi; \ diff --git a/setup.py b/setup.py index 89301d1b..594070cd 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,8 @@ def get_requires(): extra_require = { "torch": ["torch>=1.13.1"], - "torch-npu": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"], + "torch-npu-arm64": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"], + "torch-npu-amd64": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], "deepspeed": ["deepspeed>=0.10.0"], "bitsandbytes": ["bitsandbytes>=0.39.0"],