Merge pull request #4781 from hzhaoy/fix-dockerfile-cuda
Fix cuda Dockerfile
This commit is contained in:
commit
5da54deb50
|
@ -5,6 +5,7 @@ FROM nvcr.io/nvidia/pytorch:24.02-py3
|
|||
# Define environments
|
||||
ENV MAX_JOBS=4
|
||||
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
|
||||
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
|
||||
# Define installation arguments
|
||||
ARG INSTALL_BNB=false
|
||||
|
@ -23,13 +24,6 @@ RUN pip config set global.index-url "$PIP_INDEX" && \
|
|||
python -m pip install --upgrade pip && \
|
||||
python -m pip install -r requirements.txt
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||
pip uninstall -y ninja && pip install ninja && \
|
||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||
fi
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
COPY . /app
|
||||
|
||||
|
@ -46,6 +40,13 @@ RUN EXTRA_PACKAGES="metrics"; \
|
|||
fi; \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||
pip uninstall -y ninja && pip install ninja && \
|
||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||
fi
|
||||
|
||||
# Set up volumes
|
||||
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||
|
||||
|
|
|
@ -134,7 +134,7 @@ class PissaConvertCallback(TrainerCallback):
|
|||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir))
|
||||
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
|
||||
if isinstance(model, PeftModel):
|
||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||
|
|
Loading…
Reference in New Issue