loose gemma2 attention
parent 0e0d69b77c
commit 2f4b89ace1
@@ -32,8 +32,14 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
-        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
-        model_args.flash_attn = "disabled"
+        if model_args.flash_attn == "auto":
+            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+            model_args.flash_attn = "disabled"
+        else:
+            logger.warning(
+                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
+                "Will proceed at your own risk.".format(model_args.flash_attn)
+            )
 
     if model_args.flash_attn == "auto":
         return
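To make the behavior change concrete, here is a minimal, self-contained sketch of the new control flow. It is not LlamaFactory's actual module: a SimpleNamespace stands in for the real ModelArguments and PretrainedConfig objects, and a plain logging logger replaces the project's logger, so the snippet runs on its own.

import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def configure_attn_implementation(config, model_args, is_trainable: bool) -> None:
    # Sketch of the patched logic: only the "auto" default is rewritten,
    # an explicit user choice is kept and merely warned about.
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
        if model_args.flash_attn == "auto":
            logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
            model_args.flash_attn = "disabled"
        else:
            logger.warning(
                "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. "
                "Will proceed at your own risk.".format(model_args.flash_attn)
            )

    if model_args.flash_attn == "auto":
        return  # nothing else to configure in this sketch


config = SimpleNamespace(model_type="gemma2")

auto_args = SimpleNamespace(flash_attn="auto")
configure_attn_implementation(config, auto_args, is_trainable=True)
print(auto_args.flash_attn)  # "disabled": the auto default is still switched off

fa2_args = SimpleNamespace(flash_attn="fa2")
configure_attn_implementation(config, fa2_args, is_trainable=True)
print(fa2_args.flash_attn)  # "fa2": an explicit setting is now respected, with a warning

Before this commit, both calls would have ended with flash_attn set to "disabled"; afterwards only the "auto" default is overridden, which is the "loosening" the commit title refers to.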
@@ -1,7 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
-#
-# This code is inspired by the HuggingFace's transformers library.
-# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
+# Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.