forked from p04798526/LLaMA-Factory-Mirror
bf16 by default, gemma2 attns
Gemma2 finetuning cannot work until https://github.com/huggingface/transformers/pull/31674 is merged.
parent 64f4337dac
commit 4d35e218b1
@@ -28,7 +28,13 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+def configure_attn_implementation(
+    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
+) -> None:
+    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # gemma2 adopts soft-cap attention
+        logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.")
+        model_args.flash_attn = "disabled"
+
     if model_args.flash_attn == "auto":
         return
 
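For context: Gemma-2 soft-caps its attention logits before the softmax, which the fused flash-attention kernels could not reproduce at the time, hence the forced fallback to eager attention during training. A minimal sketch of the soft-capping transform (illustrative code, not from this repo; the cap value 50.0 is Gemma-2's published attention setting, assumed here for the example):

import torch

def softcap_logits(scores: torch.Tensor, cap: float = 50.0) -> torch.Tensor:
    # Scale down, squash with tanh, scale back up: logits stay in (-cap, cap)
    # while remaining smooth and differentiable.
    return cap * torch.tanh(scores / cap)

scores = torch.randn(1, 8, 16, 16)  # (batch, heads, query_len, kv_len) raw q @ k^T scores
capped = softcap_logits(scores)     # applied before masking and softmax in Gemma-2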
@@ -67,7 +67,7 @@ def patch_config(
         use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
         torch.npu.set_compile_mode(jit_compile=use_jit_compile)
 
-    configure_attn_implementation(config, model_args)
+    configure_attn_implementation(config, model_args, is_trainable)
     configure_rope(config, model_args, is_trainable)
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
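Threading is_trainable through patch_config confines the override to training runs, so inference can keep whatever flash_attn setting the user chose. A self-contained sketch of the guard's semantics (stand-in names, not LLaMA-Factory's actual classes):

from dataclasses import dataclass

@dataclass
class StubModelArgs:
    # Stand-in for ModelArguments; only the field the guard touches.
    flash_attn: str = "auto"

def resolve_flash_attn(model_type: str, args: StubModelArgs, is_trainable: bool) -> str:
    # Mirrors the hunk above: gemma2 + training forces eager attention.
    if model_type == "gemma2" and is_trainable:
        args.flash_attn = "disabled"
    return args.flash_attn

assert resolve_flash_attn("gemma2", StubModelArgs(), is_trainable=True) == "disabled"
assert resolve_flash_attn("gemma2", StubModelArgs(), is_trainable=False) == "auto"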
@@ -54,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         num_train_epochs = gr.Textbox(value="3.0")
         max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+        compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
 
     input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(
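Downstream of the webui, the compute_type choice typically toggles the trainer's mixed-precision flags, so defaulting to bf16 avoids fp16 overflow issues on hardware with bfloat16 support. A hypothetical mapping (illustrative only, not LLaMA-Factory's actual wiring; "pure_bf16" would cast the whole model to bfloat16 rather than use mixed precision):

def precision_kwargs(compute_type: str) -> dict:
    # Keyword arguments one might splat into transformers.TrainingArguments.
    if compute_type == "bf16":
        return {"bf16": True}
    if compute_type == "fp16":
        return {"fp16": True}
    return {}  # "fp32" and "pure_bf16" are handled outside these two flags

print(precision_kwargs("bf16"))  # {'bf16': True}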