support unsloth

This commit is contained in:
hiyouga 2023-12-23 00:14:33 +08:00
parent 315b8367cb
commit 7aad0b889d
11 changed files with 224 additions and 171 deletions

View File

@ -55,6 +55,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation for faster LoRA tuning. Try `--use_unsloth` argument to active unsloth patch. See performance comparisons [here](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage. [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.

View File

@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## 更新日志 ## 更新日志
[23/12/23] 我们支持了使用 **[unsloth](https://github.com/unslothai/unsloth)** 大幅提升 LoRA 训练效率。请使用 `--use_unsloth` 参数启用 unsloth 优化。性能对比请查阅[此处](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。 [23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。

View File

@ -3,7 +3,7 @@ transformers>=4.36.2
datasets>=2.14.3 datasets>=2.14.3
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.7.0 peft>=0.7.0
trl==0.7.4 trl>=0.7.6
gradio>=3.38.0,<4.0.0 gradio>=3.38.0,<4.0.0
scipy scipy
sentencepiece sentencepiece

View File

@ -38,7 +38,7 @@ class LoraArguments:
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."} metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
) )
lora_dropout: Optional[float] = field( lora_dropout: Optional[float] = field(
default=0.1, default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."} metadata={"help": "Dropout rate for the LoRA fine-tuning."}
) )
lora_rank: Optional[int] = field( lora_rank: Optional[int] = field(
@ -70,7 +70,7 @@ class RLHFArguments:
default=0.1, default=0.1,
metadata={"help": "The beta parameter for the DPO loss."} metadata={"help": "The beta parameter for the DPO loss."}
) )
dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field( dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
default="sigmoid", default="sigmoid",
metadata={"help": "The type of DPO loss to use."} metadata={"help": "The type of DPO loss to use."}
) )
@ -133,38 +133,7 @@ class RLHFArguments:
@dataclass @dataclass
class ExportArguments: class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r"""
Arguments pertaining to model exporting.
"""
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."}
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."}
)
export_quantization_maxlen: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."}
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
r""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
@ -201,7 +170,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.stage == "ppo" and self.reward_model is None: if self.stage == "ppo" and self.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.") raise ValueError("Reward model is necessary for PPO training.")
@ -209,9 +177,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.") raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"

View File

@ -54,6 +54,10 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
) )
use_unsloth: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use unsloth's optimization for LoRA training."}
)
hf_hub_token: Optional[str] = field( hf_hub_token: Optional[str] = field(
default=None, default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."} metadata={"help": "Auth token to log in with Hugging Face Hub."}
@ -62,6 +66,30 @@ class ModelArguments:
default=None, default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."} metadata={"help": "Auth token to log in with ModelScope Hub."}
) )
export_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory to save the exported model."}
)
export_size: Optional[int] = field(
default=1,
metadata={"help": "The file shard size (in GB) of the exported model."}
)
export_quantization_bit: Optional[int] = field(
default=None,
metadata={"help": "The number of bits to quantize the exported model."}
)
export_quantization_dataset: Optional[str] = field(
default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
)
export_quantization_nsamples: Optional[int] = field(
default=128,
metadata={"help": "The number of samples used for quantization."}
)
export_quantization_maxlen: Optional[int] = field(
default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."}
)
def __post_init__(self): def __post_init__(self):
self.compute_dtype = None self.compute_dtype = None
@ -74,6 +102,10 @@ class ModelArguments:
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return asdict(self) return asdict(self)

View File

@ -93,16 +93,31 @@ def init_adapter(
else: else:
target_modules = finetuning_args.lora_target target_modules = finetuning_args.lora_target
lora_config = LoraConfig( peft_kwargs = {
task_type=TaskType.CAUSAL_LM, "r": finetuning_args.lora_rank,
inference_mode=False, "target_modules": target_modules,
r=finetuning_args.lora_rank, "lora_alpha": finetuning_args.lora_alpha,
lora_alpha=finetuning_args.lora_alpha, "lora_dropout": finetuning_args.lora_dropout
lora_dropout=finetuning_args.lora_dropout, }
target_modules=target_modules,
modules_to_save=finetuning_args.additional_target if model_args.use_unsloth:
) from unsloth import FastLlamaModel, FastMistralModel # type: ignore
model = get_peft_model(model, lora_config) unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
if getattr(model.config, "model_type", None) == "llama":
model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
elif getattr(model.config, "model_type", None) == "mistral":
model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else:
raise NotImplementedError
else:
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
**peft_kwargs
)
model = get_peft_model(model, lora_config)
for param in filter(lambda p: p.requires_grad, model.parameters()): for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32) param.data = param.data.to(torch.float32)

View File

@ -4,10 +4,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
import llmtuner.model.patcher as patcher
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import count_parameters, try_download_model_from_ms from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
from llmtuner.model.adapter import init_adapter from llmtuner.model.adapter import init_adapter
from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
from llmtuner.model.utils import ( from llmtuner.model.utils import (
load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
) )
@ -24,7 +24,7 @@ require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl==0.7.4", "To fix: pip install trl==0.7.4") require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
def load_model_and_tokenizer( def load_model_and_tokenizer(
@ -52,26 +52,48 @@ def load_model_and_tokenizer(
model_args.model_name_or_path, model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer, use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens, split_special_tokens=model_args.split_special_tokens,
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow padding_side="right",
**config_kwargs **config_kwargs
) )
patch_tokenizer(tokenizer)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
patcher.patch_tokenizer(tokenizer) model = None
patcher.patch_config(config, model_args) if is_trainable and model_args.use_unsloth:
patcher.configure_rope(config, model_args, is_trainable) require_version("unsloth==2023.12", "Follow the instructions at: https://github.com/unslothai/unsloth")
patcher.configure_flashattn(config_kwargs, model_args) from unsloth import FastLlamaModel, FastMistralModel # type: ignore
patcher.configure_longlora(config, model_args, is_trainable) unsloth_kwargs = {
patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args) "model_name": model_args.model_name_or_path,
"max_seq_length": model_args.model_max_length,
"load_in_4bit": model_args.quantization_bit == 4,
"token": model_args.hf_hub_token,
"device_map": get_current_device(),
"rope_scaling": getattr(config, "rope_scaling", None)
}
if getattr(config, "model_type", None) == "llama":
model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
elif getattr(config, "model_type", None) == "mistral":
model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
else:
logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
model_args.use_unsloth = False
if model_args.adapter_name_or_path:
model_args.adapter_name_or_path = None
logger.warning("Unsloth does not support loading adapters.")
if model is None:
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
patcher.patch_model(model) patch_model(model)
register_autoclass(config, model, tokenizer) register_autoclass(config, model, tokenizer)
if not is_deepspeed_zero3_enabled(): if not is_deepspeed_zero3_enabled():
resize_embedding_layer(model, tokenizer) resize_embedding_layer(model, tokenizer)
@ -81,7 +103,7 @@ def load_model_and_tokenizer(
if add_valuehead: if add_valuehead:
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
patcher.patch_valuehead_model(model) patch_valuehead_model(model)
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
vhead_path = model_args.adapter_name_or_path[-1] vhead_path = model_args.adapter_name_or_path[-1]
@ -94,7 +116,7 @@ def load_model_and_tokenizer(
logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
if not is_trainable: if not is_trainable:
model.requires_grad_(False) # fix all model params model.requires_grad_(False)
model.eval() model.eval()
else: else:
model.train() model.train()

View File

@ -104,8 +104,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
# Setup logging # Setup logging
if training_args.should_log: if training_args.should_log:
log_level = training_args.get_process_log_level() _set_transformers_logging()
_set_transformers_logging(log_level)
# Check arguments # Check arguments
data_args.init_for_training(training_args.seed) data_args.init_for_training(training_args.seed)

View File

@ -18,85 +18,14 @@ from llmtuner.extras.packages import is_flash_attn2_available
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments, FinetuningArguments from llmtuner.hparams import ModelArguments
logger = get_logger(__name__) logger = get_logger(__name__)
SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"): def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if model_args.flash_attn and is_flash_attn2_available():
config_kwargs["use_flash_attention_2"] = True
logger.info("Using FlashAttention-2 for faster training and inference.")
def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def configure_quantization(
config: "PretrainedConfig",
config_kwargs: Dict[str, Any],
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments"
):
r"""
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq or awq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
elif finetuning_args.export_quantization_bit is not None: # gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
config_kwargs["quantization_config"] = GPTQConfig(
bits=finetuning_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if model_args.rope_scaling is not None: if model_args.rope_scaling is not None:
if not hasattr(config, "rope_scaling"): if not hasattr(config, "rope_scaling"):
logger.warning("Current model does not support RoPE scaling.") logger.warning("Current model does not support RoPE scaling.")
@ -123,27 +52,94 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
)) ))
def get_quantization_dataset( def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]):
if model_args.flash_attn and is_flash_attn2_available():
config_kwargs["use_flash_attention_2"] = True
config_kwargs["torch_dtype"] = model_args.compute_dtype
logger.info("Using FlashAttention-2 for faster training and inference.")
def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
if is_trainable and model_args.shift_attn:
if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
setattr(config, "group_size_ratio", 0.25)
logger.info("Using shift short attention with group_size_ratio=1/4.")
else:
logger.warning("Current model does not support shift short attention.")
def _configure_quantization(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments" config_kwargs: Dict[str, Any]
) -> List[str]: ):
r"""
Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
"""
if getattr(config, "quantization_config", None): # gptq or awq
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
config_kwargs["device_map"] = {"": get_current_device()}
quantization_config = getattr(config, "quantization_config", None)
logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
elif model_args.export_quantization_bit is not None: # gptq
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
from accelerate.utils import get_max_memory
if getattr(config, "model_type", None) == "chatglm":
raise ValueError("ChatGLM model is not supported.")
config_kwargs["quantization_config"] = GPTQConfig(
bits=model_args.export_quantization_bit,
tokenizer=tokenizer,
dataset=get_quantization_dataset(tokenizer, model_args)
)
config_kwargs["device_map"] = "auto"
config_kwargs["max_memory"] = get_max_memory()
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
elif model_args.quantization_bit is not None: # bnb
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
config_kwargs["device_map"] = {"": get_current_device()}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
r""" r"""
Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
""" """
if os.path.isfile(finetuning_args.export_quantization_dataset): if os.path.isfile(model_args.export_quantization_dataset):
data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None) data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
data_files = finetuning_args.export_quantization_dataset data_files = model_args.export_quantization_dataset
else: else:
data_path = finetuning_args.export_quantization_dataset data_path = model_args.export_quantization_dataset
data_files = None data_files = None
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
maxlen = finetuning_args.export_quantization_maxlen maxlen = model_args.export_quantization_maxlen
samples = [] samples = []
for _ in range(finetuning_args.export_quantization_nsamples): for _ in range(model_args.export_quantization_nsamples):
while True: while True:
sample_idx = random.randint(0, len(dataset) - 1) sample_idx = random.randint(0, len(dataset) - 1)
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
@ -151,13 +147,24 @@ def get_quantization_dataset(
break # TODO: fix large maxlen break # TODO: fix large maxlen
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen] input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
return samples return samples
def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"): def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool
):
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype) setattr(config, "torch_dtype", model_args.compute_dtype)
@ -166,6 +173,11 @@ def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype) setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
_configure_rope(config, model_args, is_trainable)
_configure_flashattn(model_args, config_kwargs)
_configure_longlora(config, model_args, is_trainable)
_configure_quantization(config, tokenizer, model_args, config_kwargs)
def patch_model(model: "PreTrainedModel"): def patch_model(model: "PreTrainedModel"):
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
@ -177,15 +189,15 @@ def patch_model(model: "PreTrainedModel"):
def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"): def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
return self.pretrained_model.get_input_embeddings() if isinstance(self.pretrained_model, PreTrainedModel):
self.pretrained_model.tie_weights()
def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
if isinstance(self.pretrained_model, PreTrainedModel):
return self.pretrained_model.get_input_embeddings()
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules) setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

View File

@ -16,7 +16,7 @@ class CustomDPOTrainer(DPOTrainer):
def __init__( def __init__(
self, self,
beta: float, beta: float,
loss_type: Literal["sigmoid", "hinge"], loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
ftx_gamma: float, ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module], model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@ -28,16 +28,20 @@ class CustomDPOTrainer(DPOTrainer):
if ref_model is not None: if ref_model is not None:
disable_dropout_in_model(ref_model) disable_dropout_in_model(ref_model)
self.is_encoder_decoder = model.config.is_encoder_decoder
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning self.use_dpo_data_collator = True # hack to avoid warning
self.generate_during_eval = False # disable at evaluation self.generate_during_eval = False # disable at evaluation
self.label_pad_token_id = IGNORE_INDEX self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0 self.padding_value = 0
self.is_encoder_decoder = model.config.is_encoder_decoder
self.precompute_ref_log_probs = False
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
self.ref_model = ref_model
self.beta = beta self.beta = beta
self.label_smoothing = 0 self.label_smoothing = 0
self.ftx_gamma = ftx_gamma
self.loss_type = loss_type self.loss_type = loss_type
self.ftx_gamma = ftx_gamma
self._stored_metrics = defaultdict(lambda: defaultdict(list)) self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs) Trainer.__init__(self, model=model, **kwargs)
@ -95,7 +99,7 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def get_batch_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",
batch: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor],

View File

@ -37,7 +37,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
def export_model(args: Optional[Dict[str, Any]] = None): def export_model(args: Optional[Dict[str, Any]] = None):
model_args, _, finetuning_args, _ = get_infer_args(args) model_args, _, finetuning_args, _ = get_infer_args(args)
if model_args.adapter_name_or_path is not None and finetuning_args.export_quantization_bit is not None: if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.") raise ValueError("Please merge adapters before quantizing the model.")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
@ -47,12 +47,12 @@ def export_model(args: Optional[Dict[str, Any]] = None):
model.config.use_cache = True model.config.use_cache = True
model = model.to("cpu") model = model.to("cpu")
model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size)) model.save_pretrained(model_args.export_dir, max_shard_size="{}GB".format(model_args.export_size))
try: try:
tokenizer.padding_side = "left" # restore padding side tokenizer.padding_side = "left" # restore padding side
tokenizer.init_kwargs["padding_side"] = "left" tokenizer.init_kwargs["padding_side"] = "left"
tokenizer.save_pretrained(finetuning_args.export_dir) tokenizer.save_pretrained(model_args.export_dir)
except: except:
logger.warning("Cannot save tokenizer, please copy the files manually.") logger.warning("Cannot save tokenizer, please copy the files manually.")