support docker-npu-[amd64|arm64] build

This commit is contained in:
fanjunliang 2024-06-27 15:21:55 +08:00
parent 8096f94a7d
commit bdda0827b3
4 changed files with 13 additions and 6 deletions

View File

@ -465,7 +465,7 @@ For Ascend NPU users:
```bash
# Choose docker image upon your environment
docker build -f ./docker/docker-npu/Dockerfile \
docker build --platform linux/arm64 -f ./docker/docker-npu/Dockerfile \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

View File

@ -465,7 +465,7 @@ docker exec -it llamafactory bash
```bash
# 根据您的环境选择镜像
docker build -f ./docker/docker-npu/Dockerfile \
docker build --platform linux/arm64 -f ./docker/docker-npu/Dockerfile \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

View File

@ -1,10 +1,11 @@
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
FROM --platform=$TARGETPLATFORM cosdt/cann:8.0.rc1-910b-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive
# Define installation arguments
ARG TARGETPLATFORM
ARG INSTALL_DEEPSPEED=false
ARG PIP_INDEX=https://pypi.org/simple
ARG EXTRA_INDEX=https://download.pytorch.org/whl/cpu
@ -15,7 +16,6 @@ WORKDIR /app
# Install the requirements
COPY requirements.txt /app
RUN pip config set global.index-url $PIP_INDEX && \
pip config set global.extra-index-url $EXTRA_INDEX && \
pip install --upgrade pip && \
pip install -r requirements.txt
@ -23,7 +23,13 @@ RUN pip config set global.index-url $PIP_INDEX && \
COPY . /app
# Install the LLaMA Factory
RUN EXTRA_PACKAGES="torch-npu,metrics"; \
RUN EXTRA_PACKAGES="metrics"; \
if [ "$TARGETPLATFORM" == "linux/arm64" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},torch-npu-arm64"; \
else \
pip config set global.extra-index-url $EXTRA_INDEX; \
EXTRA_PACKAGES="${EXTRA_PACKAGES},torch-npu-amd64"; \
fi; \
if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \

View File

@ -35,7 +35,8 @@ def get_requires():
extra_require = {
"torch": ["torch>=1.13.1"],
"torch-npu": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"],
"torch-npu-arm64": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
"torch-npu-amd64": ["torch==2.1.0+cpu", "torch-npu==2.1.0.post3", "decorator"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"deepspeed": ["deepspeed>=0.10.0"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],