From 3b9eee8cd26cfeef945155815175831dec98eb20 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 4 Jun 2023 00:08:56 +0800
Subject: [PATCH] support QLoRA

---
 README.md           | 11 +++++++++--
 src/utils/common.py | 44 ++++++++++++++++++++++++++++++++++----------
 src/utils/config.py |  5 +++++
 3 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 3deaf23b..d38ec043 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,14 @@
 
 ## Changelog
 
+[23/06/03] Now we support quantized training and inference (aka QLoRA). Try the `--quantization_bit 4/8` argument to work with quantized models (experimental feature).
+
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
 ## Supported Models
 
-- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
+- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
+- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
 
 ## Supported Training Approaches
 
@@ -22,12 +24,15 @@
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [RLHF](https://arxiv.org/abs/2203.02155)
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 
 ## Provided Datasets
 
@@ -209,6 +214,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --predict_with_generate
 ```
 
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for INT8 evaluation.
+
 ### CLI Demo
 
 ```bash
diff --git a/src/utils/common.py b/src/utils/common.py
index a0b9b551..57143195 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Seq2SeqTrainingArguments
+    Seq2SeqTrainingArguments,
+    BitsAndBytesConfig
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -167,12 +168,27 @@ def load_pretrained(
 
     # Quantization configurations (using bitsandbytes library).
     if model_args.quantization_bit is not None:
-        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
-        #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        config_kwargs["load_in_8bit"] = True
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["load_in_8bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
+            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            config_kwargs["load_in_4bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+        else:
+            raise NotImplementedError
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
@@ -183,7 +199,7 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the model weights are float16 type
+        torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
@@ -237,13 +253,13 @@ def prepare_args(
 
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True at the PT, RM and PPO stages.")
 
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
     if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -257,6 +273,14 @@ def prepare_args(
 
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
+    if model_args.quantization_bit is not None:
+        if training_args.fp16:
+            finetuning_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            finetuning_args.compute_dtype = torch.bfloat16
+        else:
+            finetuning_args.compute_dtype = torch.float32
+
     # Log on each process the small summary:
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
diff --git a/src/utils/config.py b/src/utils/config.py
index 3d9af49b..be03478d 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -1,5 +1,6 @@
 import os
 import json
+import torch
 
 from typing import List, Literal, Optional
 from dataclasses import asdict, dataclass, field
@@ -207,6 +208,10 @@ class FinetuningArguments:
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
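
For reference, a minimal standalone sketch of the 4-bit loading path this patch wires into `load_pretrained`. The model id and the concrete quant type, double-quantization, and compute-dtype values below are illustrative placeholders; in the repo they come from `model_args.quantization_type`, `model_args.double_quantization`, and `finetuning_args.compute_dtype`, which `prepare_args` derives from the `--fp16`/`--bf16` flags. The patch additionally sets `load_in_4bit` in `config_kwargs`; this sketch relies on `quantization_config` alone.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder values; the repo reads these from model_args / finetuning_args.
compute_dtype = torch.bfloat16  # finetuning_args.compute_dtype (from --bf16)
quant_type = "nf4"              # model_args.quantization_type ("nf4" or "fp4")
double_quant = True             # model_args.double_quantization

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,    # dtype used for matmuls on dequantized weights
    bnb_4bit_use_double_quant=double_quant,  # also quantize the quantization constants
    bnb_4bit_quant_type=quant_type,
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                   # placeholder model id
    quantization_config=quantization_config,
    torch_dtype=compute_dtype,
    low_cpu_mem_usage=True,
)
```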