support DoRA, AWQ, AQLM #2512
This commit is contained in:
parent
511b15b96a
commit
cfefacaa37
|
@ -56,6 +56,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
|
||||
|
||||
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
|
||||
|
||||
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
||||
|
|
|
@ -56,6 +56,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||
|
||||
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
|
||||
|
||||
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||
|
|
|
@ -2,8 +2,8 @@ torch>=1.13.1
|
|||
transformers>=4.37.2
|
||||
datasets>=2.14.3
|
||||
accelerate>=0.27.2
|
||||
peft>=0.8.2
|
||||
trl>=0.7.6
|
||||
peft>=0.9.0
|
||||
trl>=0.7.11
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
einops
|
||||
|
|
|
@ -842,10 +842,18 @@ register_model_group(
|
|||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
||||
},
|
||||
"Yi-6B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
||||
},
|
||||
"Yi-34B-int8-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
||||
},
|
||||
"Yi-34B-int4-Chat": {
|
||||
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
||||
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
||||
},
|
||||
},
|
||||
template="yi",
|
||||
)
|
||||
|
|
|
@ -74,6 +74,9 @@ class LoraArguments:
|
|||
default=False,
|
||||
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||
)
|
||||
use_dora: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
|
||||
)
|
||||
create_new_adapter: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||
|
|
|
@ -36,9 +36,9 @@ def _check_dependencies(disabled: bool) -> None:
|
|||
else:
|
||||
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
|
||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
|
||||
require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
|
||||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
|
||||
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
|
@ -144,7 +144,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
|
||||
raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
|
||||
raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("DoRA does not support quantization.")
|
||||
|
||||
if model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_check_dependencies(disabled=finetuning_args.disable_version_checking)
|
||||
|
|
|
@ -85,7 +85,7 @@ def init_adapter(
|
|||
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
adapter_to_resume = None
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
|
@ -123,6 +123,10 @@ def init_adapter(
|
|||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
if getattr(model, "quantization_method", None):
|
||||
raise ValueError("DoRA is currently not compatible with quantized models.")
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
|
@ -141,6 +145,7 @@ def init_adapter(
|
|||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
modules_to_save=finetuning_args.additional_target,
|
||||
use_dora=finetuning_args.use_dora,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
|
|
@ -51,7 +51,7 @@ def load_model_and_tokenizer(
|
|||
patch_tokenizer(tokenizer)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
|
||||
patch_config(config, tokenizer, model_args, finetuning_args, config_kwargs, is_trainable)
|
||||
|
||||
model = None
|
||||
if is_trainable and model_args.use_unsloth:
|
||||
|
|
|
@ -24,7 +24,7 @@ if TYPE_CHECKING:
|
|||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
@ -253,6 +253,7 @@ def patch_config(
|
|||
config: "PretrainedConfig",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
config_kwargs: Dict[str, Any],
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
|
@ -273,6 +274,9 @@ def patch_config(
|
|||
|
||||
_configure_quantization(config, tokenizer, model_args, config_kwargs)
|
||||
|
||||
if finetuning_args.use_dora:
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
|
||||
|
||||
def patch_model(
|
||||
model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
|
||||
|
|
Loading…
Reference in New Issue