diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index dfd90936..9021d277 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -28,7 +28,13 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+        model_args.flash_attn = "disabled"
+
     if model_args.flash_attn == "auto":
         return
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 24cd2601..4eae0bb4 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -67,7 +67,7 @@ def patch_config(
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
         torch.npu.set_compile_mode(jit_compile=use_jit_compile)
 
-    configure_attn_implementation(config, model_args)
+    configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index f33c37ee..d4832dd3 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -54,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         num_train_epochs = gr.Textbox(value="3.0")
         max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+        compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
 
     input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(