support unsloth
This commit is contained in:
parent 315b8367cb
commit 7aad0b889d
@@ -55,6 +55,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation for faster LoRA tuning. Try the `--use_unsloth` argument to activate the unsloth patch. See performance comparisons [here](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
 
 [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
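For readers who want to try the new flag end to end, a minimal sketch follows. It drives training through `run_exp`, whose dict-based `args` signature appears in a later hunk of this diff; the import path, model name, dataset, template, and output directory are placeholder assumptions rather than values taken from this commit, and other arguments may be required in practice.

```python
from llmtuner import run_exp  # assumption: run_exp is exposed at the package level

# Minimal LoRA SFT run with the unsloth patch enabled. Everything except
# `use_unsloth`, `finetuning_type` and the LoRA options is a placeholder.
run_exp({
    "stage": "sft",
    "do_train": True,
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",  # placeholder model
    "dataset": "alpaca_gpt4_en",                       # placeholder dataset name
    "template": "default",                             # placeholder prompt template
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",                    # placeholder target modules
    "use_unsloth": True,                               # the flag added in this commit
    "output_dir": "saves/llama2-7b-unsloth-lora",      # placeholder output path
})
```

Note that this commit also changes the default `lora_dropout` from 0.1 to 0.0, so pass it explicitly if the previous behavior is wanted.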
@@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## Changelog
 
+[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**, which greatly speeds up LoRA training. Use the `--use_unsloth` argument to enable the unsloth optimization. See performance comparisons [here](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison).
+
 [23/12/12] We supported fine-tuning the latest mixture-of-experts model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. See the hardware requirements [here](#硬件依赖).
 
 [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage.
@@ -3,7 +3,7 @@ transformers>=4.36.2
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0
-trl==0.7.4
+trl>=0.7.6
 gradio>=3.38.0,<4.0.0
 scipy
 sentencepiece
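Since the pinned `trl==0.7.4` becomes a lower bound of `trl>=0.7.6`, a quick way to confirm an existing environment satisfies the new requirements is sketched below; it uses only the standard library and assumes the packages are already installed.

```python
from importlib.metadata import version

# Print installed versions of the dependencies whose constraints matter for
# this commit; compare them against requirements.txt by eye.
for package, constraint in [("trl", ">=0.7.6"), ("peft", ">=0.7.0"), ("transformers", ">=4.36.2")]:
    print(f"{package} {version(package)} (required: {constraint})")
```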
@@ -38,7 +38,7 @@ class LoraArguments:
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
     )
     lora_dropout: Optional[float] = field(
-        default=0.1,
+        default=0.0,
         metadata={"help": "Dropout rate for the LoRA fine-tuning."}
     )
     lora_rank: Optional[int] = field(
@@ -70,7 +70,7 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
-    dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
+    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."}
     )
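The two new `dpo_loss` choices come from trl's `DPOTrainer`. As a rough reference for what the options mean, here is a standalone sketch of the sigmoid, hinge, and IPO objectives in terms of chosen/rejected log-probability ratios; it mirrors the commonly published formulations rather than quoting trl's exact implementation, and the `kto` variant is omitted.

```python
from typing import Dict

import torch
import torch.nn.functional as F

def preference_losses(
    policy_chosen_logps: torch.Tensor,    # log p_theta(y_w | x)
    policy_rejected_logps: torch.Tensor,  # log p_theta(y_l | x)
    ref_chosen_logps: torch.Tensor,       # log p_ref(y_w | x)
    ref_rejected_logps: torch.Tensor,     # log p_ref(y_l | x)
    beta: float = 0.1,
) -> Dict[str, torch.Tensor]:
    # Implicit reward margin between the chosen and rejected completions.
    logits = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    return {
        "sigmoid": -F.logsigmoid(beta * logits),    # original DPO loss
        "hinge": torch.relu(1.0 - beta * logits),   # hinge variant
        "ipo": (logits - 1.0 / (2.0 * beta)) ** 2,  # IPO regression-style loss
    }
```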
@@ -133,38 +133,7 @@ class RLHFArguments:
 
 
 @dataclass
-class ExportArguments:
-    r"""
-    Arguments pertaining to model exporting.
-    """
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )
-    export_size: Optional[int] = field(
-        default=1,
-        metadata={"help": "The file shard size (in GB) of the exported model."}
-    )
-    export_quantization_bit: Optional[int] = field(
-        default=None,
-        metadata={"help": "The number of bits to quantize the exported model."}
-    )
-    export_quantization_dataset: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
-    )
-    export_quantization_nsamples: Optional[int] = field(
-        default=128,
-        metadata={"help": "The number of samples used for quantization."}
-    )
-    export_quantization_maxlen: Optional[int] = field(
-        default=1024,
-        metadata={"help": "The maximum length of the model inputs used for quantization."}
-    )
-
-
-@dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
@@ -201,7 +170,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
-        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
 
         if self.stage == "ppo" and self.reward_model is None:
             raise ValueError("Reward model is necessary for PPO training.")
@@ -209,9 +177,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, ExportA
         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
             raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
 
-        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
-            raise ValueError("Quantization dataset is necessary for exporting.")
-
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
@@ -54,6 +54,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
+    use_unsloth: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use unsloth's optimization for LoRA training."}
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
@@ -62,6 +66,30 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."}
     )
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )
+    export_size: Optional[int] = field(
+        default=1,
+        metadata={"help": "The file shard size (in GB) of the exported model."}
+    )
+    export_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the exported model."}
+    )
+    export_quantization_dataset: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
+    )
+    export_quantization_nsamples: Optional[int] = field(
+        default=128,
+        metadata={"help": "The number of samples used for quantization."}
+    )
+    export_quantization_maxlen: Optional[int] = field(
+        default=1024,
+        metadata={"help": "The maximum length of the model inputs used for quantization."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
@@ -74,6 +102,10 @@ class ModelArguments:
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 
         assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
+
+        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+            raise ValueError("Quantization dataset is necessary for exporting.")
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
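With the export options now living on `ModelArguments`, a quantized export call might look like the sketch below. It uses `export_model`, whose dict-based signature appears later in this diff; the import path, model path, output directory, and calibration dataset are illustrative placeholders, and per the validation above a quantization dataset must accompany `export_quantization_bit`.

```python
from llmtuner import export_model  # assumption: export_model is exposed at the package level

# Export an already-merged model as 4-bit GPTQ shards; all paths and the
# calibration dataset name are placeholders, not values taken from this commit.
export_model({
    "model_name_or_path": "path/to/merged-model",   # placeholder
    "template": "default",                          # placeholder prompt template
    "export_dir": "exported/model-gptq-4bit",       # placeholder output path
    "export_size": 2,                               # shard size in GB
    "export_quantization_bit": 4,                   # must be one of 2/3/4/8
    "export_quantization_dataset": "path/to/calibration.json",  # required when quantizing
})
```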
@@ -93,16 +93,31 @@ def init_adapter(
         else:
             target_modules = finetuning_args.lora_target
 
-        lora_config = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            inference_mode=False,
-            r=finetuning_args.lora_rank,
-            lora_alpha=finetuning_args.lora_alpha,
-            lora_dropout=finetuning_args.lora_dropout,
-            target_modules=target_modules,
-            modules_to_save=finetuning_args.additional_target
-        )
-        model = get_peft_model(model, lora_config)
+        peft_kwargs = {
+            "r": finetuning_args.lora_rank,
+            "target_modules": target_modules,
+            "lora_alpha": finetuning_args.lora_alpha,
+            "lora_dropout": finetuning_args.lora_dropout
+        }
+
+        if model_args.use_unsloth:
+            from unsloth import FastLlamaModel, FastMistralModel # type: ignore
+            unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
+            if getattr(model.config, "model_type", None) == "llama":
+                model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+            elif getattr(model.config, "model_type", None) == "mistral":
+                model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
+            else:
+                raise NotImplementedError
+        else:
+            lora_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                modules_to_save=finetuning_args.additional_target,
+                **peft_kwargs
+            )
+            model = get_peft_model(model, lora_config)
 
         for param in filter(lambda p: p.requires_grad, model.parameters()):
             param.data = param.data.to(torch.float32)
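Outside of LLaMA-Factory, the same unsloth entry points can be used directly. The sketch below sticks to the keyword arguments this hunk passes (`model` and `max_seq_length` plus the LoRA hyperparameters) and to the `from_pretrained` keywords used in the loader hunk further down; the model name and hyperparameter values are placeholder assumptions, and any other unsloth option would be an assumption as well.

```python
from unsloth import FastLlamaModel  # requires the unsloth package (pinned to 2023.12 in this commit)

# Load a Llama model through unsloth; the second return value is assumed to be
# the tokenizer, matching the `model, _ = ...from_pretrained(...)` pattern below.
model, tokenizer = FastLlamaModel.from_pretrained(
    model_name="meta-llama/Llama-2-7b-hf",  # placeholder
    max_seq_length=2048,
    load_in_4bit=False,
)

# Attach LoRA adapters with the same keyword arguments init_adapter uses above.
model = FastLlamaModel.get_peft_model(
    model=model,
    max_seq_length=2048,
    r=8,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=16,
    lora_dropout=0.0,
)
```

Keeping a single `peft_kwargs` dict and unpacking it into either `LoraConfig` or unsloth's `get_peft_model` is what keeps the two code paths in this hunk from drifting apart.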
@@ -4,10 +4,10 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
-import llmtuner.model.patcher as patcher
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import count_parameters, try_download_model_from_ms
+from llmtuner.extras.misc import count_parameters, get_current_device, try_download_model_from_ms
 from llmtuner.model.adapter import init_adapter
+from llmtuner.model.patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
 from llmtuner.model.utils import (
     load_valuehead_params, prepare_model_for_training, resize_embedding_layer, register_autoclass
 )
@@ -24,7 +24,7 @@ require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.
 require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl==0.7.4", "To fix: pip install trl==0.7.4")
+require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
 
 
 def load_model_and_tokenizer(
@@ -52,26 +52,48 @@ def load_model_and_tokenizer(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         split_special_tokens=model_args.split_special_tokens,
-        padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
+        padding_side="right",
         **config_kwargs
     )
+    patch_tokenizer(tokenizer)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
 
-    patcher.patch_tokenizer(tokenizer)
-    patcher.patch_config(config, model_args)
-    patcher.configure_rope(config, model_args, is_trainable)
-    patcher.configure_flashattn(config_kwargs, model_args)
-    patcher.configure_longlora(config, model_args, is_trainable)
-    patcher.configure_quantization(config, config_kwargs, tokenizer, model_args, finetuning_args)
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        config=config,
-        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
-        **config_kwargs
-    )
+    model = None
+    if is_trainable and model_args.use_unsloth:
+        require_version("unsloth==2023.12", "Follow the instructions at: https://github.com/unslothai/unsloth")
+        from unsloth import FastLlamaModel, FastMistralModel # type: ignore
+        unsloth_kwargs = {
+            "model_name": model_args.model_name_or_path,
+            "max_seq_length": model_args.model_max_length,
+            "load_in_4bit": model_args.quantization_bit == 4,
+            "token": model_args.hf_hub_token,
+            "device_map": get_current_device(),
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+        if getattr(config, "model_type", None) == "llama":
+            model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
+        elif getattr(config, "model_type", None) == "mistral":
+            model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
+        else:
+            logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+            model_args.use_unsloth = False
+
+        if model_args.adapter_name_or_path:
+            model_args.adapter_name_or_path = None
+            logger.warning("Unsloth does not support loading adapters.")
+
+    if model is None:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            config=config,
+            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
+            **config_kwargs
+        )
+
     model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
-    patcher.patch_model(model)
+    patch_model(model)
     register_autoclass(config, model, tokenizer)
     if not is_deepspeed_zero3_enabled():
         resize_embedding_layer(model, tokenizer)
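`get_current_device` is imported from `llmtuner.extras.misc` but is not shown in this diff; it supplies the `device_map` entry for the unsloth and quantized-loading paths. A plausible stand-in, stated as an assumption for illustration and not the project's actual implementation, would resolve the process-local GPU and fall back to CPU:

```python
import os
from typing import Union

import torch

def get_current_device() -> Union[int, str]:
    # Illustrative assumption: return the CUDA device index owned by this process
    # (respecting LOCAL_RANK under distributed launchers) or "cpu" otherwise,
    # so the value can be used directly as a device_map target.
    if torch.cuda.is_available():
        return int(os.environ.get("LOCAL_RANK", "0"))
    return "cpu"
```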
@@ -81,7 +103,7 @@ def load_model_and_tokenizer(
 
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
-        patcher.patch_valuehead_model(model)
+        patch_valuehead_model(model)
 
         if model_args.adapter_name_or_path is not None:
             vhead_path = model_args.adapter_name_or_path[-1]
@@ -94,7 +116,7 @@ def load_model_and_tokenizer(
             logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
 
     if not is_trainable:
-        model.requires_grad_(False) # fix all model params
+        model.requires_grad_(False)
         model.eval()
     else:
         model.train()
@@ -104,8 +104,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     # Setup logging
     if training_args.should_log:
-        log_level = training_args.get_process_log_level()
-        _set_transformers_logging(log_level)
+        _set_transformers_logging()
 
     # Check arguments
     data_args.init_for_training(training_args.seed)
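`_set_transformers_logging` now takes no log level. Its body is not part of this diff; a hedged sketch of what such a helper typically does with the transformers logging utilities, offered as an assumption rather than the project's actual code, is:

```python
from transformers.utils import logging as hf_logging

def _set_transformers_logging(log_level: int = hf_logging.INFO) -> None:
    # Illustrative assumption: route transformers logs through its default
    # handler at a fixed verbosity instead of the per-process log level.
    hf_logging.set_verbosity(log_level)
    hf_logging.enable_default_handler()
    hf_logging.enable_explicit_format()
```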
@@ -18,85 +18,14 @@ from llmtuner.extras.packages import is_flash_attn2_available
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.hparams import ModelArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments
 
 
 logger = get_logger(__name__)
 SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
 
 
-def configure_flashattn(config_kwargs: Dict[str, Any], model_args: "ModelArguments"):
-    if model_args.flash_attn and is_flash_attn2_available():
-        config_kwargs["use_flash_attention_2"] = True
-        logger.info("Using FlashAttention-2 for faster training and inference.")
-
-
-def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
-            setattr(config, "group_size_ratio", 0.25)
-            logger.info("Using shift short attention with group_size_ratio=1/4.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
-
-
-def configure_quantization(
-    config: "PretrainedConfig",
-    config_kwargs: Dict[str, Any],
-    tokenizer: "PreTrainedTokenizer",
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments"
-):
-    r"""
-    Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
-    """
-    if getattr(config, "quantization_config", None): # gptq or awq
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-
-        config_kwargs["device_map"] = {"": get_current_device()}
-        quantization_config = getattr(config, "quantization_config", None)
-        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
-
-    elif finetuning_args.export_quantization_bit is not None: # gptq
-        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
-        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
-        from accelerate.utils import get_max_memory
-
-        if getattr(config, "model_type", None) == "chatglm":
-            raise ValueError("ChatGLM model is not supported.")
-
-        config_kwargs["quantization_config"] = GPTQConfig(
-            bits=finetuning_args.export_quantization_bit,
-            tokenizer=tokenizer,
-            dataset=get_quantization_dataset(tokenizer, model_args, finetuning_args)
-        )
-        config_kwargs["device_map"] = "auto"
-        config_kwargs["max_memory"] = get_max_memory()
-        logger.info("Quantizing model to {} bit.".format(finetuning_args.export_quantization_bit))
-
-    elif model_args.quantization_bit is not None: # bnb
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
-
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
-
-        elif model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
-
-        config_kwargs["device_map"] = {"": get_current_device()}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
-
-
-def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
+def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
     if model_args.rope_scaling is not None:
         if not hasattr(config, "rope_scaling"):
             logger.warning("Current model does not support RoPE scaling.")
@@ -123,27 +52,94 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         ))
 
 
-def get_quantization_dataset(
-    tokenizer: "PreTrainedTokenizer",
-    model_args: "ModelArguments",
-    finetuning_args: "FinetuningArguments"
-) -> List[str]:
+def _configure_flashattn(model_args: "ModelArguments", config_kwargs: Dict[str, Any]):
+    if model_args.flash_attn and is_flash_attn2_available():
+        config_kwargs["use_flash_attention_2"] = True
+        config_kwargs["torch_dtype"] = model_args.compute_dtype
+        logger.info("Using FlashAttention-2 for faster training and inference.")
+
+
+def _configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool):
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")
+
+
+def _configure_quantization(
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    config_kwargs: Dict[str, Any]
+):
+    r"""
+    Priority: Pre-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    """
+    if getattr(config, "quantization_config", None): # gptq or awq
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
+        config_kwargs["device_map"] = {"": get_current_device()}
+        quantization_config = getattr(config, "quantization_config", None)
+        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
+
+    elif model_args.export_quantization_bit is not None: # gptq
+        require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
+        require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
+        from accelerate.utils import get_max_memory
+
+        if getattr(config, "model_type", None) == "chatglm":
+            raise ValueError("ChatGLM model is not supported.")
+
+        config_kwargs["quantization_config"] = GPTQConfig(
+            bits=model_args.export_quantization_bit,
+            tokenizer=tokenizer,
+            dataset=get_quantization_dataset(tokenizer, model_args)
+        )
+        config_kwargs["device_map"] = "auto"
+        config_kwargs["max_memory"] = get_max_memory()
+        logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
+
+    elif model_args.quantization_bit is not None: # bnb
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+
+        config_kwargs["device_map"] = {"": get_current_device()}
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+
+
+def get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
     r"""
     Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
     TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
     """
-    if os.path.isfile(finetuning_args.export_quantization_dataset):
-        data_path = FILEEXT2TYPE.get(finetuning_args.export_quantization_dataset.split(".")[-1], None)
-        data_files = finetuning_args.export_quantization_dataset
+    if os.path.isfile(model_args.export_quantization_dataset):
+        data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
+        data_files = model_args.export_quantization_dataset
     else:
-        data_path = finetuning_args.export_quantization_dataset
+        data_path = model_args.export_quantization_dataset
         data_files = None
 
     dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
-    maxlen = finetuning_args.export_quantization_maxlen
+    maxlen = model_args.export_quantization_maxlen
 
     samples = []
-    for _ in range(finetuning_args.export_quantization_nsamples):
+    for _ in range(model_args.export_quantization_nsamples):
         while True:
             sample_idx = random.randint(0, len(dataset) - 1)
             sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
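The GPTQ branch above hands a list of raw text samples to `transformers.GPTQConfig`. Outside the framework, the same mechanism looks roughly like the sketch below; the model name, calibration texts, and output path are placeholders, and `optimum`/`auto-gptq` must be installed as the `require_version` calls indicate.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)

# A handful of calibration texts stands in for get_quantization_dataset's samples.
calibration_samples = ["Example calibration sentence for GPTQ."] * 128

quantized = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=GPTQConfig(bits=4, tokenizer=tokenizer, dataset=calibration_samples),
)
quantized.save_pretrained("exported/llama2-7b-gptq-4bit")  # placeholder output path
```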
@@ -151,13 +147,24 @@ def get_quantization_dataset(
                 break # TODO: fix large maxlen
 
         word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
-        input_ids = sample["input_ids"][:, word_idx:word_idx+maxlen]
+        input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
         samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
 
     return samples
 
 
-def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer"):
+    if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
+        tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
+
+
+def patch_config(
+    config: "PretrainedConfig",
+    tokenizer: "PreTrainedTokenizer",
+    model_args: "ModelArguments",
+    config_kwargs: Dict[str, Any],
+    is_trainable: bool
+):
     if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
     setattr(config, "torch_dtype", model_args.compute_dtype)
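`patch_tokenizer` (and `patch_valuehead_model` below) rely on rebinding a plain function onto an existing instance with `types.MethodType`. A minimal, self-contained illustration of that pattern, independent of transformers, is:

```python
from types import MethodType

class Greeter:
    def greet(self) -> str:
        return "hello"

def loud_greet(self: "Greeter") -> str:
    # Replacement implementation bound onto one instance only.
    return "HELLO, " + type(self).__name__.upper() + "!"

g = Greeter()
g.greet = MethodType(loud_greet, g)  # monkey-patch this instance, like tokenizer._pad above
print(g.greet())                     # -> "HELLO, GREETER!"
print(Greeter().greet())             # other instances keep the original -> "hello"
```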
@@ -166,6 +173,11 @@ def patch_config(config: "PretrainedConfig", model_args: "ModelArguments"):
     for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
         setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
 
+    _configure_rope(config, model_args, is_trainable)
+    _configure_flashattn(model_args, config_kwargs)
+    _configure_longlora(config, model_args, is_trainable)
+    _configure_quantization(config, tokenizer, model_args, config_kwargs)
+
 
 def patch_model(model: "PreTrainedModel"):
     if "GenerationMixin" not in str(model.generate.__func__):
@@ -177,15 +189,15 @@ def patch_model(model: "PreTrainedModel"):
 
 
 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead"):
-    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
-        return self.pretrained_model.get_input_embeddings()
+    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
+        if isinstance(self.pretrained_model, PreTrainedModel):
+            self.pretrained_model.tie_weights()
+
+    def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
+        if isinstance(self.pretrained_model, PreTrainedModel):
+            return self.pretrained_model.get_input_embeddings()
 
-    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
     ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
     setattr(model, "_keys_to_ignore_on_save", ignore_modules)
-    setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
+    setattr(model, "tie_weights", MethodType(tie_weights, model))
+    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
@@ -16,7 +16,7 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
-        loss_type: Literal["sigmoid", "hinge"],
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@@ -28,16 +28,20 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)
 
-        self.is_encoder_decoder = model.config.is_encoder_decoder
-        self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.generate_during_eval = False # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
+        self.is_encoder_decoder = model.config.is_encoder_decoder
+        self.precompute_ref_log_probs = False
+        self._precomputed_train_ref_log_probs = False
+        self._precomputed_eval_ref_log_probs = False
+
+        self.ref_model = ref_model
         self.beta = beta
         self.label_smoothing = 0
-        self.ftx_gamma = ftx_gamma
         self.loss_type = loss_type
+        self.ftx_gamma = ftx_gamma
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)
@@ -95,7 +99,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
 
-    def get_batch_metrics(
+    def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
         batch: Dict[str, torch.Tensor],
@@ -37,7 +37,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
 def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, _, finetuning_args, _ = get_infer_args(args)
 
-    if model_args.adapter_name_or_path is not None and finetuning_args.export_quantization_bit is not None:
+    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")
 
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
@@ -47,12 +47,12 @@ def export_model(args: Optional[Dict[str, Any]] = None):
 
     model.config.use_cache = True
     model = model.to("cpu")
-    model.save_pretrained(finetuning_args.export_dir, max_shard_size="{}GB".format(finetuning_args.export_size))
+    model.save_pretrained(model_args.export_dir, max_shard_size="{}GB".format(model_args.export_size))
 
     try:
         tokenizer.padding_side = "left" # restore padding side
         tokenizer.init_kwargs["padding_side"] = "left"
-        tokenizer.save_pretrained(finetuning_args.export_dir)
+        tokenizer.save_pretrained(model_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")
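As a sanity check after `export_model` finishes, the exported directory should load back as an ordinary Hugging Face checkpoint. A minimal sketch, where the directory path is a placeholder matching whatever `export_dir` was set to:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

export_dir = "exported/my-merged-model"  # placeholder: the export_dir used above

# The exported shards and tokenizer files should round-trip through transformers.
tokenizer = AutoTokenizer.from_pretrained(export_dir)
model = AutoModelForCausalLM.from_pretrained(export_dir)
print(model.config.model_type, sum(p.numel() for p in model.parameters()))
```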