add loftq

This commit is contained in:
hiyouga 2023-12-14 21:53:56 +08:00
parent bfdee1608f
commit d4c351f1ec
4 changed files with 90 additions and 14 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 173 KiB

After

Width:  |  Height:  |  Size: 145 KiB

View File

@ -119,16 +119,6 @@ def load_model_and_tokenizer(
model_args.rope_scaling, scaling_factor model_args.rope_scaling, scaling_factor
)) ))
# Set FlashAttention-2
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Current model automatically enables FlashAttention if installed.")
else:
setattr(config, "attn_implementation", "flash_attention_2")
logger.info("Using FlashAttention-2 for faster training and inference.")
# Set shift short attention (S^2-Attn) # Set shift short attention (S^2-Attn)
if is_trainable and model_args.shift_attn: if is_trainable and model_args.shift_attn:
logger.warning("Shift short attention is temporarily invalid due to breaking changes.") logger.warning("Shift short attention is temporarily invalid due to breaking changes.")
@ -138,10 +128,19 @@ def load_model_and_tokenizer(
# else: # else:
# logger.warning("Current model does not support shift short attention.") # logger.warning("Current model does not support shift short attention.")
# Set FlashAttention-2
if model_args.flash_attn:
if not is_flash_attn2_available():
logger.warning("FlashAttention-2 is not installed.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Current model automatically enables FlashAttention if installed.")
else:
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
# Quantization configurations (using gptq or awq) # Quantization configurations (using gptq or awq)
if getattr(config, "quantization_config", None): if getattr(config, "quantization_config", None):
if model_args.quantization_bit is not None: # remove bnb quantization model_args.quantization_bit = None # remove bnb quantization
model_args.quantization_bit = None
config_kwargs["device_map"] = {"": get_current_device()} config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None) quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1))) logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))

View File

@ -26,7 +26,7 @@ def calculate_lr(
cutoff_len: int, # i.e. maximum input length during training cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate, is_mistral: bool, # mistral model uses a smaller learning rate,
dataset_dir: Optional[str] = "data" dataset_dir: Optional[str] = "../data"
): ):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict( model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft", stage="sft",
@ -38,7 +38,7 @@ def calculate_lr(
output_dir="dummy_dir" output_dir="dummy_dir"
)) ))
trainset = get_dataset(model_args, data_args) trainset = get_dataset(model_args, data_args)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft") _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft") trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
dataloader = DataLoader( dataloader = DataLoader(

77
tests/loftq_init.py Normal file
View File

@ -0,0 +1,77 @@
# coding=utf-8
# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
import os
import fire
import torch
import torch.nn as nn
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
class Shell(nn.Module):
def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=False)
if bias is not None:
self.bias = nn.Parameter(bias, requires_grad=False)
def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):
parent_name = ".".join(name.split(".")[:-1])
child_name = name.split(".")[-1]
parent_module = model.get_submodule(parent_name)
child_module = getattr(parent_module, child_name)
base_layer = getattr(child_module, "base_layer")
weight = getattr(base_layer, "weight", None)
bias = getattr(base_layer, "bias", None)
setattr(parent_module, child_name, Shell(weight, bias))
print("Model unwrapped.")
def quantize_loftq(
model_name_or_path: str,
save_dir: str,
loftq_bits: Optional[int] = 4,
loftq_iter: Optional[int] = 1,
lora_alpha: Optional[int] = None,
lora_rank: Optional[int] = 16,
lora_target: Optional[str] = "q_proj,v_proj"
):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=True,
r=lora_rank,
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
lora_dropout=0.1,
target_modules=[name.strip() for name in lora_target.split(",")],
init_lora_weights="loftq",
loftq_config=loftq_config
)
# Init LoftQ model
lora_model = get_peft_model(model, lora_config)
base_model = lora_model.get_base_model()
# Save LoftQ model
setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
lora_model.save_pretrained(os.path.join(save_dir, "adapters"))
# Save base model
unwrap_model(base_model)
base_model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
if __name__ == "__main__":
fire.Fire(quantize_loftq)