diff --git a/README.md b/README.md
index 3c78c971..3537fd10 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Changelog
 
+[23/12/12] We supported **[LoftQ](https://arxiv.org/abs/2310.08659)** training. Try `--loftq_init` argument with `peft==0.7.1.dev0` for better QLoRA performance.
+
 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
 
 [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
@@ -67,7 +69,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
 
-[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
 
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
 
diff --git a/README_zh.md b/README_zh.md
index a91e2571..9527e976 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -55,6 +55,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 
 ## Changelog
 
+[23/12/12] We supported **[LoftQ](https://arxiv.org/abs/2310.08659)** training. Use the `--loftq_init` argument together with `peft==0.7.1.dev0` to improve QLoRA performance.
+
 [23/12/12] We supported fine-tuning the latest mixture-of-experts model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. See [here](#硬件依赖) for hardware requirements.
 
 [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage details.
@@ -67,13 +69,13 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 
 [23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks into this project. See [this example](#模型评估) for usage.
 
-[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. If you are using an RTX4090, A100 or H100 GPU, use the `--flash_attn` argument to enable FlashAttention-2.
+[23/09/10] We supported **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**. If you are using an RTX4090, A100 or H100 GPU, use the `--flash_attn` argument to enable FlashAttention-2.
 
 [23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Use the `--rope_scaling linear` argument for training or the `--rope_scaling dynamic` argument for evaluation.
 
 [23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-训练) for usage.
 
-[23/07/31] We supported **dataset streaming**. Try the `--streaming` and `--max_steps 10000` arguments to load datasets in streaming mode.
+[23/07/31] We supported **dataset streaming**. Use the `--streaming` and `--max_steps 10000` arguments to load datasets in streaming mode.
 
 [23/07/29] We released two 13B instruction-tuned models on Hugging Face. See our Hugging Face projects ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
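For reference, the `--loftq_init` option introduced in this patch maps onto peft's LoftQ initialization API, as wired up in `src/llmtuner/model/adapter.py` below. The following is a minimal sketch of the equivalent direct `peft` usage; the model id and LoRA hyperparameters are illustrative and not taken from this patch:

```python
# Minimal sketch of the LoftQ + LoRA setup that `--loftq_init` enables (illustrative values).
# Requires peft>=0.7.1 with LoftQ support; the model id and hyperparameters are examples only.
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# LoftQ performs its own quantization during adapter initialization, so the base model is
# loaded in full/half precision here instead of through bitsandbytes (see loader.py below).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")

loftq_config = LoftQConfig(loftq_bits=4)  # 4-bit LoftQ, matching the adapter.py change below
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    init_lora_weights="loftq",            # initialize LoRA weights with LoftQ
    loftq_config=loftq_config,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```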
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index ae3a6f79..87f8f8a4 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -55,6 +55,10 @@ class LoraArguments:
                     Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
                     Others choices: the same as LLaMA."}
     )
+    loftq_init: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use LoftQ initialization for quantized LoRA fine-tuning."}
+    )
     resume_lora_training: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index 72cea444..9404e707 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -91,6 +91,16 @@ def init_adapter(
         else:
             target_modules = finetuning_args.lora_target
 
+        config_kwargs = {}
+        if model_args.quantization_bit is not None and finetuning_args.loftq_init:
+            if model_args.quantization_bit != 4:
+                raise ValueError("LoftQ initialization only supports 4-bit quantized training.")
+
+            from peft import LoftQConfig  # type: ignore
+            loftq_config = LoftQConfig(loftq_bits=4)
+            config_kwargs["init_lora_weights"] = "loftq"
+            config_kwargs["loftq_config"] = loftq_config
+
         lora_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             inference_mode=False,
@@ -98,7 +108,8 @@ def init_adapter(
             lora_alpha=finetuning_args.lora_alpha,
             lora_dropout=finetuning_args.lora_dropout,
             target_modules=target_modules,
-            modules_to_save=finetuning_args.additional_target
+            modules_to_save=finetuning_args.additional_target,
+            **config_kwargs
         )
         model = get_peft_model(model, lora_config)
 
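The `loader.py` hunk that follows keeps the bitsandbytes path for ordinary QLoRA but skips it when LoftQ is enabled, since LoftQ handles quantization during adapter initialization in peft. A condensed sketch of that decision logic; the helper name and the nf4/double-quantization defaults are illustrative stand-ins for the repo's argument objects:

```python
# Condensed sketch (not the actual repo code) of the quantization decision in loader.py below.
import torch
from transformers import BitsAndBytesConfig

def build_quantization_kwargs(quantization_bit, loftq_init, compute_dtype=torch.bfloat16):
    """Return kwargs for AutoModelForCausalLM.from_pretrained according to the chosen path."""
    config_kwargs = {}
    if quantization_bit is None:
        return config_kwargs  # no quantization requested
    if loftq_init:
        # LoftQ quantizes inside peft, so no bitsandbytes config is passed to from_pretrained.
        return config_kwargs
    if quantization_bit == 8:
        config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    elif quantization_bit == 4:
        config_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,   # assumed default, mirrors double_quantization
            bnb_4bit_quant_type="nf4",        # assumed default, mirrors quantization_type
        )
    return config_kwargs
```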
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index fb9b08f1..e3d6ebba 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -144,28 +144,32 @@ def load_model_and_tokenizer(
             model_args.quantization_bit = None
         config_kwargs["device_map"] = {"": get_current_device()}
         quantization_config = getattr(config, "quantization_config", None)
-        logger.info("Loading {}-bit quantized model.".format(quantization_config.get("bits", -1)))
+        logger.info("Loading {}-bit pre-quantized model.".format(quantization_config.get("bits", -1)))
 
-    # Quantization configurations (using bitsandbytes library)
+    # Quantization configurations (using bitsandbytes)
     if model_args.quantization_bit is not None:
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
 
-        if model_args.quantization_bit == 8:
-            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+        if finetuning_args.loftq_init:
+            require_version("peft>=0.7.1.dev0", "To fix: pip install git+https://github.com/hiyouga/peft.git")
+            logger.info("Skipping bitsandbytes quantization because LoftQ is enabled.")
+        else:
+            if model_args.quantization_bit == 8:
+                require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+                config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 
-        if model_args.quantization_bit == 4:
-            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=model_args.compute_dtype,
-                bnb_4bit_use_double_quant=model_args.double_quantization,
-                bnb_4bit_quant_type=model_args.quantization_type
-            )
+            if model_args.quantization_bit == 4:
+                require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+                config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=model_args.compute_dtype,
+                    bnb_4bit_use_double_quant=model_args.double_quantization,
+                    bnb_4bit_quant_type=model_args.quantization_type
+                )
 
-        config_kwargs["device_map"] = {"": get_current_device()}
-        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
+            config_kwargs["device_map"] = {"": get_current_device()}
+            logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(