improve model export
parent f6fdd83f8a
commit d2a676c8ba
@@ -192,7 +192,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.resume_from_checkpoint
         ))
 
-    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+    if (
+        finetuning_args.stage in ["rm", "ppo"]
+        and finetuning_args.finetuning_type == "lora"
+        and training_args.resume_from_checkpoint is not None
+    ):
         logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
             training_args.resume_from_checkpoint
         ))
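For context, a hedged sketch of what the warning above asks the user to do when resuming a LoRA-based RM/PPO run. The dict-style invocation and the paths are illustrative and not part of this commit; only the argument names that appear in the hunk are taken from it.

    # Illustrative only: resume a LoRA reward-model run by also passing the saved
    # checkpoint as an adapter, as the warning suggests. Paths are placeholders.
    resume_args = {
        "stage": "rm",
        "finetuning_type": "lora",
        "adapter_name_or_path": "saves/base_adapter,saves/rm/checkpoint-1000",  # checkpoint appended
        "resume_from_checkpoint": "saves/rm/checkpoint-1000",
    }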
@@ -83,46 +83,47 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments"
 
 
 def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if model_args.rope_scaling is not None:
-        if not hasattr(config, "rope_scaling"):
-            logger.warning("Current model does not support RoPE scaling.")
-            return
-
-        if is_trainable:
-            if model_args.rope_scaling == "dynamic":
-                logger.warning(
-                    "Dynamic NTK scaling may not work well with fine-tuning. "
-                    "See: https://github.com/huggingface/transformers/pull/24653"
-                )
-
-            current_max_length = getattr(config, "max_position_embeddings", None)
-            if current_max_length and model_args.model_max_length > current_max_length:
-                scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
-            else:
-                logger.warning("Input length is smaller than max length. Consider increase input length.")
-                scaling_factor = 1.0
-        else:
-            scaling_factor = 2.0
-
-        setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-        logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
-            model_args.rope_scaling, scaling_factor
-        ))
+    if not hasattr(config, "rope_scaling"):
+        logger.warning("Current model does not support RoPE scaling.")
+        return
+
+    if is_trainable:
+        if model_args.rope_scaling == "dynamic":
+            logger.warning(
+                "Dynamic NTK scaling may not work well with fine-tuning. "
+                "See: https://github.com/huggingface/transformers/pull/24653"
+            )
+
+        current_max_length = getattr(config, "max_position_embeddings", None)
+        if current_max_length and model_args.model_max_length > current_max_length:
+            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+        else:
+            logger.warning("Input length is smaller than max length. Consider increase input length.")
+            scaling_factor = 1.0
+    else:
+        scaling_factor = 2.0
+
+    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+    logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+        model_args.rope_scaling, scaling_factor
+    ))
 
 
-def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
-    if model_args.flash_attn and is_flash_attn2_available():
-        config_kwargs["use_flash_attention_2"] = True
-        config_kwargs["torch_dtype"] = model_args.compute_dtype
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
+    if not is_flash_attn2_available():
+        logger.warning("FlashAttention2 is not installed.")
+        return
+
+    config_kwargs["use_flash_attention_2"] = True
+    logger.info("Using FlashAttention-2 for faster training and inference.")
 
 
-def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
-            setattr(config, "group_size_ratio", 0.25)
-            logger.info("Using shift short attention with group_size_ratio=1/4.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
+def _configure_longlora(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+        setattr(config, "group_size_ratio", 0.25)
+        logger.info("Using shift short attention with group_size_ratio=1/4.")
+    else:
+        logger.warning("Current model does not support shift short attention.")
 
 
 def _configure_quantization(
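As a worked example of the scaling arithmetic in `_configure_rope` above (the numbers and the "linear" strategy are assumed for illustration, not taken from this commit):

    import math

    # Assumed example: a model pretrained with max_position_embeddings=4096,
    # fine-tuned with model_max_length=8192 and rope_scaling="linear".
    current_max_length = 4096  # getattr(config, "max_position_embeddings", None)
    model_max_length = 8192    # model_args.model_max_length
    scaling_factor = float(math.ceil(model_max_length / current_max_length))  # -> 2.0
    rope_scaling = {"type": "linear", "factor": scaling_factor}  # value written to config.rope_scaling
    # When is_trainable is False, the function instead uses the fixed factor 2.0.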
@@ -132,9 +133,9 @@ def _configure_quantization(
     config_kwargs: Dict[str, Any]
 ) -> None:
     r"""
-    Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
-    if getattr(config, "quantization_config", None): # gptq or awq
+    if getattr(config, "quantization_config", None): # gptq
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
 
@@ -142,9 +143,9 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
             quantization_config["use_exllama"] = False # disable exllama
-        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
+        logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
 
-    elif model_args.export_quantization_bit is not None: # gptq
+    elif model_args.export_quantization_bit is not None: # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
         require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
         from accelerate.utils import get_max_memory
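For reference, a sketch of the GPTQ `quantization_config` block that this branch reads from a pre-quantized checkpoint's config.json. The field values are illustrative assumptions; the conditional mirrors the hunk above.

    # Illustrative GPTQ quantization_config from a pre-quantized checkpoint (values assumed).
    quantization_config = {
        "quant_method": "gptq",
        "bits": 4,
        "group_size": 128,
        "desc_act": False,
    }
    # The exllama kernel is disabled only for 4-bit GPTQ checkpoints:
    if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
        quantization_config["use_exllama"] = False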
@@ -232,15 +233,20 @@ def patch_config(
 ) -> None:
     if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
-    setattr(config, "torch_dtype", model_args.compute_dtype)
 
     if getattr(config, "model_type", None) == "qwen":
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
-            setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
+            setattr(config, dtype_name, model_args.compute_dtype == dtype)
 
-    _configure_rope(config, model_args, is_trainable)
-    _configure_flashattn(model_args, config_kwargs)
-    _configure_longlora(config, model_args, is_trainable)
+    if model_args.rope_scaling is not None:
+        _configure_rope(config, model_args, is_trainable)
+
+    if model_args.flash_attn:
+        _configure_flashattn(config_kwargs)
+
+    if is_trainable and model_args.shift_attn:
+        _configure_longlora(config)
+
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
 
@@ -1,3 +1,4 @@
+import torch
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from llmtuner.extras.callbacks import LogCallback
@@ -46,7 +47,12 @@ def export_model(args: Optional[Dict[str, Any]] = None):
         logger.warning("Cannot merge adapters to a quantized model.")
 
     model.config.use_cache = True
-    model = model.to("cpu")
+    if getattr(model.config, "torch_dtype", None) == "bfloat16":
+        model = model.to(torch.bfloat16).to("cpu")
+    else:
+        model = model.to(torch.float16).to("cpu")
+        setattr(model.config, "torch_dtype", "float16")
+
     model.save_pretrained(
         save_directory=model_args.export_dir,
         max_shard_size="{}GB".format(model_args.export_size),
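A hedged usage sketch of the export path touched above: only `export_dir` and `export_size` appear in the hunk; the dict-style call, the package-level import of `export_model`, and the paths are assumptions for illustration.

    # Illustrative only: merge a LoRA adapter and write a half-precision export.
    from llmtuner import export_model  # assumed package-level re-export

    export_model({
        "model_name_or_path": "path/to/base_model",
        "adapter_name_or_path": "path/to/lora_adapter",
        "export_dir": "path/to/merged_model",
        "export_size": 2,  # shards of at most "2GB", per max_shard_size above
        # further arguments (e.g. template) may be required depending on the model
    })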