support QLoRA

This commit is contained in:
hiyouga 2023-06-04 00:08:56 +08:00
parent 1bd13d7ca1
commit 3b9eee8cd2
3 changed files with 48 additions and 12 deletions

View File

@ -9,12 +9,14 @@
## Changelog
[23/06/03] Now we support quantized training and inference (aka QLoRA). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
## Supported Models
- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
## Supported Training Approaches
@ -22,12 +24,15 @@
- Full-parameter training
- Partial-parameter training
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
- Full-parameter training
- Partial-parameter training
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
## Provided Datasets
@ -209,6 +214,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--predict_with_generate
```
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.
### CLI Demo
```bash

View File

@ -11,7 +11,8 @@ from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Seq2SeqTrainingArguments
Seq2SeqTrainingArguments,
BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
@ -167,12 +168,27 @@ def load_pretrained(
# Quantization configurations (using bitsandbytes library).
if model_args.quantization_bit is not None:
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
#require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
#require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
#require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
config_kwargs["load_in_8bit"] = True
if model_args.quantization_bit == 8:
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
config_kwargs["load_in_8bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0
)
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
bnb_4bit_use_double_quant=model_args.double_quantization,
bnb_4bit_quant_type=model_args.quantization_type
)
else:
raise NotImplementedError
is_mergeable = False
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@ -183,7 +199,7 @@ def load_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=torch.float16, # the model weights are float16 type
torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
low_cpu_mem_usage=True,
**config_kwargs
)
@ -237,13 +253,13 @@ def prepare_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
if stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
if training_args.do_predict and (not training_args.predict_with_generate):
raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@ -257,6 +273,14 @@ def prepare_args(
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
if training_args.fp16:
finetuning_args.compute_dtype = torch.float16
elif training_args.bf16:
finetuning_args.compute_dtype = torch.bfloat16
else:
finetuning_args.compute_dtype = torch.float32
# Log on each process the small summary:
logger.info(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"

View File

@ -1,5 +1,6 @@
import os
import json
import torch
from typing import List, Literal, Optional
from dataclasses import asdict, dataclass, field
@ -207,6 +208,10 @@ class FinetuningArguments:
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
)
def __post_init__(self):
if isinstance(self.lora_target, str):