fix layer norm dtype
parent b0b0138e1d · commit 84b7486885
@@ -2,7 +2,7 @@ IGNORE_INDEX = -100
 
 LOG_FILE_NAME = "trainer_log.jsonl"
 
-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
 
 METHODS = ["full", "freeze", "lora"]
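Note: the two new entries cover block-level layer norm names used by GPT-2-style architectures. A minimal standalone sketch of how this list is consumed, a substring match against parameter names (the parameter names below are made up for illustration; the real match happens in prepare_model_for_training, shown further down):

LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]

param_names = [
    "transformer.h.0.ln_1.weight",             # matched via "ln_1"
    "model.layers.0.input_layernorm.weight",   # matched via "norm"
    "model.layers.0.mlp.down_proj.weight",     # not a layer norm, skipped
]
matched = [n for n in param_names if any(ln_name in n for ln_name in LAYERNORM_NAMES)]
print(matched)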
@@ -19,21 +19,6 @@ except ImportError:
 logger = logging.get_logger(__name__)
 
 
-class LlamaRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return (self.weight * hidden_states).to(input_dtype)
-
-
 class LlamaShiftShortAttention(LlamaAttention):
 
     def forward(
@@ -162,6 +147,14 @@ class LlamaFlashAttention2(LlamaAttention):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        # cast to half precision
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            logger.warning_once("The input hidden states seems to be silently casted in float32.")
+            query_states = query_states.to(torch.float16)
+            key_states = key_states.to(torch.float16)
+            value_states = value_states.to(torch.float16)
+
         if getattr(self, "num_key_value_groups"):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
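Note: FlashAttention kernels only accept fp16/bf16 inputs, so fp32 activations (which can appear once layer norms are kept in float32) are downcast before the kernel call. A standalone sketch of the same pattern, with arbitrary tensor shapes and float16 hard-coded just as the hunk above does:

import torch

query_states = torch.randn(1, 8, 32, 64, dtype=torch.float32)  # (batch, heads, seq, head_dim)
if query_states.dtype == torch.float32:
    # Downcast so a half-precision attention kernel can consume the tensor.
    query_states = query_states.to(torch.float16)
print(query_states.dtype)  # torch.float16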
@@ -67,6 +67,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
+        default="auto",
+        metadata={"help": "Data type of the layer norm weights."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
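Note: assuming the surrounding project parses ModelArguments with transformers.HfArgumentParser (not shown in this diff) and a transformers version whose parser supports Literal choices, the new field becomes a --layernorm_dtype CLI option restricted to the four listed values. A trimmed-down sketch with only the new field:

from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
        default="auto",
        metadata={"help": "Data type of the layer norm weights."}
    )


parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--layernorm_dtype", "bf16"])
print(model_args.layernorm_dtype)  # "bf16"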
@@ -128,10 +128,6 @@ def load_model_and_tokenizer(
     else:
         logger.warning("Current model does not support RoPE scaling.")
 
-    # Fix RMSNorm in fp32 weight (https://github.com/huggingface/transformers/pull/23535)
-    if getattr(config, "model_type", None) == "llama":
-        LlamaModule.LlamaRMSNorm = LlamaPatches.LlamaRMSNorm
-
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
@@ -205,7 +201,8 @@ def load_model_and_tokenizer(
         tokenizer.__class__.register_for_auto_class()
 
     # Initialize adapters
-    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
+    if is_trainable:
+        model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
     model = model.train() if is_trainable else model.eval()
@@ -226,6 +226,17 @@ def get_train_args(
     else:
         model_args.compute_dtype = _infer_dtype()
 
+    if model_args.layernorm_dtype == "bf16":
+        if not is_bf16_available:
+            raise ValueError("Current device does not support bf16 type.")
+        model_args.layernorm_dtype = torch.bfloat16
+    elif model_args.layernorm_dtype == "fp16":
+        model_args.layernorm_dtype = torch.float16
+    elif model_args.layernorm_dtype == "fp32":
+        model_args.layernorm_dtype = torch.float32
+    else:
+        model_args.layernorm_dtype = model_args.compute_dtype
+
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
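Note: the block above mutates model_args.layernorm_dtype in place, turning the string option into a torch.dtype before it reaches prepare_model_for_training. The same mapping written as a standalone helper (resolve_layernorm_dtype is a hypothetical name, not part of this commit):

import torch


def resolve_layernorm_dtype(option: str, compute_dtype: torch.dtype, bf16_available: bool) -> torch.dtype:
    if option == "bf16":
        if not bf16_available:
            raise ValueError("Current device does not support bf16 type.")
        return torch.bfloat16
    if option == "fp16":
        return torch.float16
    if option == "fp32":
        return torch.float32
    return compute_dtype  # "auto": fall back to the training compute dtype


print(resolve_layernorm_dtype("auto", torch.float16, bf16_available=False))  # torch.float16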
@@ -31,6 +31,7 @@ def find_all_linear_modules(
 
 def prepare_model_for_training(
     model: "PreTrainedModel",
+    layernorm_dtype: torch.dtype,
     finetuning_type: str,
     output_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
@@ -45,7 +46,7 @@ def prepare_model_for_training(
     """
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            param.data = param.data.to(torch.float32)
+            param.data = param.data.to(layernorm_dtype)
 
     if use_gradient_checkpointing:
         if hasattr(model, "enable_input_require_grads"):
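Note: with this change the 1-D layer-norm weights are cast to the requested layernorm_dtype instead of always being forced to float32. A runnable toy sketch of that loop (the Block module and its field names are made up; layer_norm_names mirrors the LAYERNORM_NAMES constant from this commit):

import torch
import torch.nn as nn

layer_norm_names = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.ln_1 = nn.LayerNorm(8)


model = Block().half()           # all parameters start in fp16
layernorm_dtype = torch.float32  # e.g. resolved from --layernorm_dtype fp32

for name, param in model.named_parameters():
    if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
        param.data = param.data.to(layernorm_dtype)

print(model.ln_1.weight.dtype)   # torch.float32
print(model.proj.weight.dtype)   # torch.float16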