update loader

This commit is contained in:
hiyouga 2023-12-24 19:10:23 +08:00
parent e44b82ee24
commit 6629087e12
6 changed files with 67 additions and 68 deletions

View File

@ -145,13 +145,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
default="lora",
metadata={"help": "Which fine-tuning method to use."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)
plot_loss: Optional[bool] = field(
default=False,
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
metadata={"help": "Whether or not to save the training loss curves."}
)
def __post_init__(self):

View File

@ -20,11 +20,11 @@ class ModelArguments:
)
use_fast_tokenizer: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}
)
resize_vocab: Optional[bool] = field(
default=False,
metadata={"help": "Whether to resize the tokenizer vocab and the embedding layers."}
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
)
split_special_tokens: Optional[bool] = field(
default=False,
@ -44,11 +44,11 @@ class ModelArguments:
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Whether to use double quantization in int4 training or not."}
metadata={"help": "Whether or not to use double quantization in int4 training."}
)
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None,
metadata={"help": "Adopt scaled rotary positional embeddings."}
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
)
flash_attn: Optional[bool] = field(
default=False,
@ -60,7 +60,15 @@ class ModelArguments:
)
use_unsloth: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use unsloth's optimization for LoRA training."}
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
)
disable_gradient_checkpointing: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
)
hf_hub_token: Optional[str] = field(
default=None,

View File

@ -8,7 +8,7 @@ from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from llmtuner.model.adapter import init_adapter
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, register_autoclass
from llmtuner.model.utils import load_valuehead_params, register_autoclass
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
@ -92,10 +92,9 @@ def load_model_and_tokenizer(
)
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
patch_model(model, tokenizer, model_args)
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
if add_valuehead:

View File

@ -144,7 +144,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
_verify_model_args(model_args, finetuning_args)
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):

View File

@ -3,14 +3,14 @@ import math
import torch
import random
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from datasets import load_dataset
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from llmtuner.extras.constants import FILEEXT2TYPE
from llmtuner.extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
@ -180,6 +180,42 @@ def _configure_quantization(
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def _prepare_model_for_training(
model: "PreTrainedModel",
model_args: "ModelArguments",
output_layer_name: Optional[str] = "lm_head"
) -> None:
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) add the upcasting of the lm_head in fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
"""
if model_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
logger.info("Upcasting layernorm weights in float32.")
if not model_args.disable_gradient_checkpointing:
if getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
else:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if hasattr(model, output_layer_name):
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
output_layer.register_forward_hook(fp32_forward_post_hook)
def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
@ -206,7 +242,12 @@ def patch_config(
_configure_quantization(config, tokenizer, model_args, config_kwargs)
def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
def patch_model(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
is_trainable: bool
) -> None:
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
@ -220,6 +261,10 @@ def patch_model(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", mode
_resize_embedding_layer(model, tokenizer)
if is_trainable:
_prepare_model_for_training(model, model_args)
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
if isinstance(self.pretrained_model, PreTrainedModel):

View File

@ -1,19 +1,15 @@
import math
import torch
import inspect
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List
from transformers.utils import cached_file
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import get_current_device
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from llmtuner.hparams import DataArguments
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
logger = get_logger(__name__)
@ -123,51 +119,6 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
return None
def prepare_model_for_training(
model: "PreTrainedModel",
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
(1) cast the layernorm in fp32
(2) make output embedding layer require grads
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
if finetuning_args.upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)
logger.info("Upcasting weights in layernorm in float32.")
if use_gradient_checkpointing and getattr(model, "supports_gradient_checkpointing", False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
output_layer = getattr(model, output_layer_name)
if isinstance(output_layer, torch.nn.Linear):
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
return args[0].to(output_layer.weight.dtype)
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
return output.to(torch.float32)
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
output_layer.register_forward_hook(fp32_forward_post_hook)
return model
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()