diff --git a/README.md b/README.md
index 443c8cf7..45ac23d8 100644
--- a/README.md
+++ b/README.md
@@ -444,6 +444,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
     --build-arg INSTALL_BNB=false \
     --build-arg INSTALL_VLLM=false \
     --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg INSTALL_FLASH_ATTN=false \
     --build-arg PIP_INDEX=https://pypi.org/simple \
     -t llamafactory:latest .
 
diff --git a/README_zh.md b/README_zh.md
index d5172a7d..c5fd4f69 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -444,6 +444,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \
     --build-arg INSTALL_BNB=false \
     --build-arg INSTALL_VLLM=false \
     --build-arg INSTALL_DEEPSPEED=false \
+    --build-arg INSTALL_FLASH_ATTN=false \
     --build-arg PIP_INDEX=https://pypi.org/simple \
     -t llamafactory:latest .
 
diff --git a/docker/docker-cuda/Dockerfile b/docker/docker-cuda/Dockerfile
index 827b7b3c..44aaf538 100644
--- a/docker/docker-cuda/Dockerfile
+++ b/docker/docker-cuda/Dockerfile
@@ -6,6 +6,7 @@ FROM nvcr.io/nvidia/pytorch:24.02-py3
 ARG INSTALL_BNB=false
 ARG INSTALL_VLLM=false
 ARG INSTALL_DEEPSPEED=false
+ARG INSTALL_FLASH_ATTN=false
 ARG PIP_INDEX=https://pypi.org/simple
 
 # Set the working directory
@@ -35,6 +36,14 @@ RUN EXTRA_PACKAGES="metrics"; \
     pip install -e .[$EXTRA_PACKAGES] && \
     pip uninstall -y transformer-engine flash-attn
 
+# Optionally rebuild flash-attn (the step above uninstalls it) when INSTALL_FLASH_ATTN=true.
+# ninja is reinstalled only if `ninja --version` fails; the `;` before `fi` is required to end the command.
+RUN if [ "$INSTALL_FLASH_ATTN" = "true" ]; then \
+        (ninja --version || (pip uninstall -y ninja && pip install ninja)) && \
+        MAX_JOBS=4 pip install --no-cache-dir flash-attn --no-build-isolation; \
+    fi
+
 # Set up volumes
 VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
 
diff --git a/docker/docker-cuda/docker-compose.yml b/docker/docker-cuda/docker-compose.yml
index e2d1a5ad..4ccb0c04 100644
--- a/docker/docker-cuda/docker-compose.yml
+++ b/docker/docker-cuda/docker-compose.yml
@@ -7,6 +7,7 @@ services:
         INSTALL_BNB: false
         INSTALL_VLLM: false
         INSTALL_DEEPSPEED: false
+        INSTALL_FLASH_ATTN: false
         PIP_INDEX: https://pypi.org/simple
     container_name: llamafactory
     volumes: