add configurer

This commit is contained in:
hiyouga 2023-12-15 21:46:40 +08:00
parent 0716f5e470
commit 2740aa9cbb
3 changed files with 83 additions and 77 deletions

View File

@ -1,20 +1,20 @@
from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead
import llmtuner.model.patcher as patcher
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.hparams import FinetuningArguments
from llmtuner.extras.misc import count_parameters, try_download_model_from_ms
from llmtuner.model.adapter import init_adapter
from llmtuner.model.patches import patch_config, patch_model, patch_valuehead_model, patch_tokenizer, register_autoclass
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
from llmtuner.model.utils import (
load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
)
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
from llmtuner.hparams import ModelArguments
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
@ -55,45 +55,15 @@ def load_model_and_tokenizer(
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
**config_kwargs
)
patch_tokenizer(tokenizer)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
patch_config(config, model_args, is_trainable)
# Set FlashAttention-2
if model_args.flash_attn and is_flash_attn2_available():
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
patcher.patch_tokenizer(tokenizer)
patcher.patch_config(config, model_args, is_trainable)
patcher.configure_rope(config, model_args, is_trainable)
patcher.configure_flashattn(config, model_args)
patcher.configure_longlora(config, model_args, is_trainable)
patcher.configure_quantization(config, config_kwargs, model_args)
# Quantization configurations (using gptq or awq)
if getattr(config, "quantization_config", None):
model_args.quantization_bit = None # remove bnb quantization
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
# Quantization configurations (using bitsandbytes)
if model_args.quantization_bit is not None:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
# Load pre-trained models (without valuehead)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
@ -101,23 +71,20 @@ def load_model_and_tokenizer(
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
patch_model(model)
patcher.patch_model(model)
register_autoclass(config, model, tokenizer)
resize_embedding_layer(model, tokenizer)
# Initialize adapters
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
model = init_adapter(model, model_args, finetuning_args, is_trainable)
# Prepare model with valuehead for RLHF
if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patch_valuehead_model(model)
patcher.patch_valuehead_model(model)
vhead_params = load_valuehead_params(model_args)
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
# Prepare model for inference
if not is_trainable:
model.requires_grad_(False) # fix all model params
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model

View File

@ -1,12 +1,15 @@
import math
import torch
from types import MethodType
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import BitsAndBytesConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import infer_optim_dtype
from llmtuner.extras.misc import get_current_device, infer_optim_dtype
from llmtuner.extras.packages import is_flash_attn2_available
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
@ -15,17 +18,53 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
def patch_config(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
if model_args.flash_attn and is_flash_attn2_available():
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def configure_quantization(config: "PretrainedConfig", config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
if getattr(config, "quantization_config", None): # gptq or awq
model_args.quantization_bit = None # remove bnb quantization
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
if model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
if model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if model_args.rope_scaling is not None:
if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.")
@ -51,14 +90,15 @@ def patch_config(config: "PretrainedConfig", model_args: "ModelArguments", is_tr
model_args.rope_scaling, scaling_factor
))
# Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn:
logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
# if getattr(config, "model_type", None) == "llama":
# setattr(config, "group_size_ratio", 0.25)
# logger.info("Using shift short attention with group_size_ratio=1/4.")
# else:
# logger.warning("Current model does not support shift short attention.")
def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
if getattr(config, "model_type", None) == "qwen":
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
def patch_model(model: "PreTrainedModel"):
@ -83,12 +123,3 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizerBase"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

View File

@ -9,8 +9,7 @@ from llmtuner.extras.logging import get_logger
from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from llmtuner.hparams import DataArguments
@ -183,3 +182,12 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
new_embedding_size = model.get_input_embeddings().weight.size(0)
logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
if "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()