support QLoRA
commit 3b9eee8cd2 (parent 1bd13d7ca1)
README.md (11 lines changed)
@@ -9,12 +9,14 @@
 ## Changelog
 
+[23/06/03] Now we support quantized training and inference (aka QLoRA). Try the `--quantization_bit 4/8` argument to work with quantized models. (experimental feature)
+
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
 ## Supported Models
 
-- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
+- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
+- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
 
 ## Supported Training Approaches
@@ -22,12 +24,15 @@
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [RLHF](https://arxiv.org/abs/2203.02155)
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
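Putting the pieces above together, a QLoRA run is simply LoRA fine-tuning on a quantized base model. The sketch below is illustrative only: apart from the `--quantization_bit` flag introduced by this commit and the `src/train_sft.py` entry point already shown in this README, every flag, path, and dataset name is an assumption rather than part of this change.

```bash
# Hypothetical QLoRA invocation: LoRA adapters trained on a 4-bit quantized base.
# Only --quantization_bit comes from this commit; the other flags are placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --quantization_bit 4 \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --fp16
```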
 
 ## Provided Datasets
@@ -209,6 +214,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --predict_with_generate
 ```
 
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.
+
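To make the recommendation concrete, an INT8 evaluation run might look like the sketch below. Only `--quantization_bit`, `--per_device_eval_batch_size`, `--max_target_length`, and `--predict_with_generate` are grounded in this diff; the remaining flags and paths are illustrative assumptions.

```bash
# Hypothetical INT8 evaluation; paths, dataset and checkpoint flags are placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path path_to_llama_model \
    --do_eval \
    --dataset alpaca_gpt4_en \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_eval_result \
    --quantization_bit 8 \
    --per_device_eval_batch_size 1 \
    --max_target_length 128 \
    --predict_with_generate
```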
 ### CLI Demo
 
 ```bash
@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Seq2SeqTrainingArguments
+    Seq2SeqTrainingArguments,
+    BitsAndBytesConfig
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -167,12 +168,27 @@ def load_pretrained(
 
     # Quantization configurations (using bitsandbytes library).
     if model_args.quantization_bit is not None:
-        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
-        #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        config_kwargs["load_in_8bit"] = True
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["load_in_8bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
+            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            config_kwargs["load_in_4bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+        else:
+            raise NotImplementedError
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
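For readers unfamiliar with bitsandbytes, the 4-bit branch above amounts to a model load along the lines of this standalone sketch. The checkpoint name is a placeholder, and the concrete dtype, double-quantization, and `nf4` values are assumptions standing in for whatever `model_args` and `finetuning_args` supply at runtime.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Standalone sketch of the 4-bit (QLoRA-style) loading path added above.
# Placeholder values: in the diff they are read from model_args/finetuning_args.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,   # finetuning_args.compute_dtype
    bnb_4bit_use_double_quant=True,          # model_args.double_quantization
    bnb_4bit_quant_type="nf4",               # model_args.quantization_type
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                   # placeholder checkpoint
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)
```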
@@ -183,7 +199,7 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the model weights are float16 type
+        torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
@@ -237,13 +253,13 @@ def prepare_args(
 
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
 
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
     if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -257,6 +273,14 @@ def prepare_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
+    if model_args.quantization_bit is not None:
+        if training_args.fp16:
+            finetuning_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            finetuning_args.compute_dtype = torch.bfloat16
+        else:
+            finetuning_args.compute_dtype = torch.float32
+
     # Log on each process the small summary:
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
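Restated as a tiny helper for clarity; the function name is hypothetical and merely mirrors the branch added in this hunk, whose result is what ultimately feeds `bnb_4bit_compute_dtype` in the quantization block earlier in this diff.

```python
import torch

def resolve_compute_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    # Mirrors the added branch: --fp16 wins, then --bf16, otherwise float32.
    if fp16:
        return torch.float16
    if bf16:
        return torch.bfloat16
    return torch.float32

assert resolve_compute_dtype(fp16=False, bf16=True) is torch.bfloat16
```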
@@ -1,5 +1,6 @@
 import os
 import json
+import torch
 from typing import List, Literal, Optional
 from dataclasses import asdict, dataclass, field
 
@@ -207,6 +208,10 @@ class FinetuningArguments:
                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                  BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):