From e75407febdec086f2bdca723a7f69a92b3b1d63f Mon Sep 17 00:00:00 2001
From: S3Studio
Date: Thu, 14 Mar 2024 18:03:33 +0800
Subject: [PATCH] Use official Nvidia base image

Note that the flash-attn library is preinstalled in this image and the qwen
model will use it automatically. However, if the host machine's GPU is not
compatible with the library, an exception is raised during training:
"FlashAttention only supports Ampere GPUs or newer."

So if the --flash_attn flag is not set, an additional patch to the qwen
model's config is needed to change the default value of use_flash_attn
from "auto" to False.
---
 Dockerfile                    | 2 +-
 src/llmtuner/model/patcher.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/Dockerfile b/Dockerfile
index 155b86d4..c3d231b5 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
+FROM nvcr.io/nvidia/pytorch:24.01-py3
 
 WORKDIR /app
 
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index bd484052..210044f2 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -283,6 +283,9 @@ def patch_config(
         setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
     _configure_attn_implementation(model_args, init_kwargs)
+    if getattr(config, "model_type", None) == "qwen" and init_kwargs["attn_implementation"] != "flash_attention_2":
+        config.use_flash_attn = False
+
     _configure_rope(config, model_args, is_trainable)
     _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, is_trainable) if False else _configure_quantization(config, tokenizer, model_args, init_kwargs)
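
For context, below is a minimal standalone sketch of the fallback logic this patch adds to patch_config. The DummyConfig-style object and the disable_qwen_flash_attn helper are illustrative only and are not part of the repository; the real change lives inside patch_config as shown in the diff above.

    from types import SimpleNamespace


    def disable_qwen_flash_attn(config, init_kwargs):
        # Mirror of the added guard: if the model is qwen and flash_attention_2
        # was not selected, force use_flash_attn to False so the remote qwen
        # modeling code does not auto-enable flash-attn on unsupported GPUs.
        if getattr(config, "model_type", None) == "qwen" and init_kwargs.get("attn_implementation") != "flash_attention_2":
            config.use_flash_attn = False


    # Example: a qwen config with the default "auto" value, eager attention requested.
    config = SimpleNamespace(model_type="qwen", use_flash_attn="auto")
    disable_qwen_flash_attn(config, {"attn_implementation": "eager"})
    print(config.use_flash_attn)  # False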