support DoRA, AWQ, AQLM #2512

This commit is contained in:
hiyouga 2024-02-28 19:53:28 +08:00
parent 511b15b96a
commit cfefacaa37
9 changed files with 40 additions and 9 deletions

View File

@ -56,6 +56,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.

View File

@ -56,6 +56,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## 更新日志
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。

View File

@ -2,8 +2,8 @@ torch>=1.13.1
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.27.2
peft>=0.8.2
trl>=0.7.6
peft>=0.9.0
trl>=0.7.11
gradio>=3.38.0,<4.0.0
scipy
einops

View File

@ -842,10 +842,18 @@ register_model_group(
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
},
"Yi-6B-int4-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
},
"Yi-34B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
},
"Yi-34B-int4-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
},
},
template="yi",
)

View File

@ -74,6 +74,9 @@ class LoraArguments:
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
)
use_dora: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
)
create_new_adapter: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},

View File

@ -36,9 +36,9 @@ def _check_dependencies(disabled: bool) -> None:
else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@ -144,7 +144,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
raise ValueError("Please specify `lora_target` in LoRA training.")
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
if finetuning_args.use_dora:
if model_args.quantization_bit is not None:
raise ValueError("DoRA does not support quantization.")
if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)

View File

@ -85,7 +85,7 @@ def init_adapter(
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
adapter_to_resume = None
if model_args.adapter_name_or_path is not None:
@ -123,6 +123,10 @@ def init_adapter(
if finetuning_args.use_llama_pro:
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
if finetuning_args.use_dora:
if getattr(model, "quantization_method", None):
raise ValueError("DoRA is currently not compatible with quantized models.")
peft_kwargs = {
"r": finetuning_args.lora_rank,
"target_modules": target_modules,
@ -141,6 +145,7 @@ def init_adapter(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora,
**peft_kwargs,
)
model = get_peft_model(model, lora_config)

View File

@ -51,7 +51,7 @@ def load_model_and_tokenizer(
patch_tokenizer(tokenizer)
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
patch_config(config, tokenizer, model_args, finetuning_args, config_kwargs, is_trainable)
model = None
if is_trainable and model_args.use_unsloth:

View File

@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
from ..hparams import FinetuningArguments, ModelArguments
logger = get_logger(__name__)
@ -253,6 +253,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
config_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@ -273,6 +274,9 @@ def patch_config(
_configure_quantization(config, tokenizer, model_args, config_kwargs)
if finetuning_args.use_dora:
config_kwargs["device_map"] = {"": get_current_device()}
def patch_model(
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool