Use official Nvidia base image

Note that the flash-attn library is installed in this image and the Qwen model will use it automatically.
However, if the host machine's GPU is not compatible with the library, an exception is raised during training:
"FlashAttention only supports Ampere GPUs or newer."
So if the --flash_attn flag is not set, an additional patch to the Qwen model's config is needed to change the default value of use_flash_attn from "auto" to False.
S3Studio 2024-03-14 18:03:33 +08:00 committed by liuzhao2
parent 6a5693d11d
commit e75407febd
2 changed files with 4 additions and 1 deletion


@@ -1,4 +1,4 @@
-FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
+FROM nvcr.io/nvidia/pytorch:24.01-py3
 WORKDIR /app
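As a quick sanity check (illustration only, not part of the diff), one can confirm inside a container built from the new base image that flash-attn is present; the exact version depends on the NGC release:

try:
    import flash_attn  # preinstalled in the NGC PyTorch images
    print("flash-attn version:", flash_attn.__version__)
except ImportError:
    print("flash-attn is not installed in this image")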


@@ -283,6 +283,9 @@ def patch_config(
             setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
     _configure_attn_implementation(model_args, init_kwargs)
+    if getattr(config, "model_type", None) == "qwen" and init_kwargs["attn_implementation"] != "flash_attention_2":
+        config.use_flash_attn = False
+
     _configure_rope(config, model_args, is_trainable)
     _configure_longlora(config, model_args, is_trainable)
     _configure_quantization(config, tokenizer, model_args, init_kwargs)
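For illustration only (not part of the diff), the effect of the added guard can be reproduced on a standalone config object. The checkpoint name below is an example of a legacy remote-code Qwen model whose config exposes use_flash_attn with an "auto" default; adjust it to whatever Qwen checkpoint is being trained:

from transformers import AutoConfig

# Example legacy Qwen checkpoint with a remote-code config (assumption).
config = AutoConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
print(getattr(config, "use_flash_attn", None))  # typically "auto"

# Mirror of the patched logic: when flash_attention_2 was not requested,
# force-disable FlashAttention so the remote modeling code does not enable it
# automatically on an incompatible GPU.
attn_implementation = "eager"  # i.e. --flash_attn was not set
if getattr(config, "model_type", None) == "qwen" and attn_implementation != "flash_attention_2":
    config.use_flash_attn = False
print(config.use_flash_attn)  # False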