support BLOOM models
This commit is contained in:
parent
a72492e649
commit
740a5daf56
70
README.md
70
README.md
|
@ -5,9 +5,65 @@
|
|||
![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)
|
||||
![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
|
||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
|
||||
|
||||
## Supported Training Approach
|
||||
|
||||
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- Full-parameter training
|
||||
- Selected-parameter training
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
|
||||
- Full-parameter training
|
||||
- Selected-parameter training
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [RLHF](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
- For pre-training:
|
||||
- [Wiki Demo](data/wiki_demo.txt)
|
||||
- For supervised fine-tuning:
|
||||
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat](https://github.com/thunlp/UltraChat)
|
||||
- For reward model training:
|
||||
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
Please refer to [data/README.md](data/README.md) for details.
|
||||
|
||||
Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
|
||||
|
||||
```bash
|
||||
pip install --upgrade huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## Requirement
|
||||
|
||||
- Python 3.8+ and PyTorch 1.13.1
|
||||
- Python 3.8+ and PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
||||
- protobuf, cpm_kernels and sentencepiece
|
||||
- jieba, rouge_chinese and nltk (used at evaluation)
|
||||
|
@ -36,10 +92,10 @@ pip install -r requirements.txt
|
|||
### LLaMA Weights Preparation
|
||||
|
||||
1. Download the weights of the LLaMA models.
|
||||
2. Convert them to HF format using this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)
|
||||
2. Convert them to HF format using the following command.
|
||||
|
||||
```python
|
||||
python convert_llama_weights_to_hf.py \
|
||||
```bash
|
||||
python -m transformers.models.llama.convert_llama_weights_to_hf \
|
||||
--input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model
|
||||
```
|
||||
|
||||
|
@ -177,7 +233,11 @@ python src/export_model.py \
|
|||
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE). Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA model.
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA models.
|
||||
|
||||
Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
|
||||
|
||||
## Citation
|
||||
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
# coding=utf-8
|
||||
# Implements stream chat in command line for LLaMA fine-tuned with PEFT.
|
||||
# Implements stream chat in command line for fine-tuned models.
|
||||
# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
|
||||
|
||||
|
||||
import torch
|
||||
from utils import ModelArguments, auto_configure_device_map, load_pretrained
|
||||
from utils import ModelArguments, load_pretrained
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
|
@ -12,10 +12,11 @@ def main():
|
|||
|
||||
parser = HfArgumentParser(ModelArguments)
|
||||
model_args, = parser.parse_args_into_dataclasses()
|
||||
model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
|
||||
model, tokenizer = load_pretrained(model_args)
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
device_map = auto_configure_device_map(torch.cuda.device_count())
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
device_map = infer_auto_device_map(model)
|
||||
model = dispatch_model(model, device_map)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
@ -47,7 +48,7 @@ def main():
|
|||
return response, history
|
||||
|
||||
history = []
|
||||
print("欢迎使用 LLaMA-7B 模型,输入内容即可对话,clear清空对话历史,stop终止程序")
|
||||
print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name))
|
||||
while True:
|
||||
try:
|
||||
query = input("\nInput: ")
|
||||
|
@ -65,7 +66,7 @@ def main():
|
|||
continue
|
||||
|
||||
response, history = predict(query, history)
|
||||
print("LLaMA-7B:", response)
|
||||
print("{}:".format(model_name), response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Exports the fine-tuned LLaMA model.
|
||||
# Exports the fine-tuned model.
|
||||
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Implements parameter-efficient PPO training of fine-tuned LLaMA.
|
||||
# Implements parameter-efficient PPO training of fine-tuned models.
|
||||
# This code is inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
||||
|
||||
|
@ -15,8 +15,8 @@ from utils import (
|
|||
prepare_data,
|
||||
load_pretrained,
|
||||
preprocess_data,
|
||||
DataCollatorForLLaMA,
|
||||
PPOTrainerForLLaMA,
|
||||
DynamicDataCollatorWithPadding,
|
||||
PPOPeftTrainer,
|
||||
LogCallback,
|
||||
plot_loss
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ def main():
|
|||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)
|
||||
data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
|
||||
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
|
@ -52,7 +52,7 @@ def main():
|
|||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = PPOTrainerForLLaMA(
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[LogCallback()],
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Implements several parameter-efficient pre-training method for LLaMA.
|
||||
# Implements several parameter-efficient pre-training method.
|
||||
# This code is inspired by
|
||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
|
@ -10,7 +10,7 @@ from utils import (
|
|||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
DataCollatorForLLaMA,
|
||||
DynamicDataCollatorWithPadding,
|
||||
PeftTrainer,
|
||||
LogCallback,
|
||||
plot_loss
|
||||
|
@ -24,7 +24,7 @@ def main():
|
|||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
|
||||
data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Implements parameter-efficient training of a reward model based on LLaMA.
|
||||
# Implements parameter-efficient training of reward models.
|
||||
# This code is inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
@ -10,8 +10,8 @@ from utils import (
|
|||
prepare_data,
|
||||
load_pretrained,
|
||||
preprocess_data,
|
||||
PairwiseDataCollatorForLLaMA,
|
||||
PairwiseTrainerForLLaMA,
|
||||
PairwiseDataCollatorWithPadding,
|
||||
PairwisePeftTrainer,
|
||||
LogCallback,
|
||||
plot_loss
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ def main():
|
|||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
|
@ -38,7 +38,7 @@ def main():
|
|||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwiseTrainerForLLaMA(
|
||||
trainer = PairwisePeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Implements several parameter-efficient supervised fine-tuning method for LLaMA.
|
||||
# Implements several parameter-efficient supervised fine-tuning method.
|
||||
# This code is inspired by
|
||||
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
|
@ -9,8 +9,8 @@ from utils import (
|
|||
prepare_args,
|
||||
prepare_data,
|
||||
preprocess_data,
|
||||
DataCollatorForLLaMA,
|
||||
Seq2SeqTrainerForLLaMA,
|
||||
DynamicDataCollatorWithPadding,
|
||||
Seq2SeqPeftTrainer,
|
||||
ComputeMetrics,
|
||||
LogCallback,
|
||||
get_logits_processor,
|
||||
|
@ -25,7 +25,7 @@ def main():
|
|||
dataset = prepare_data(model_args, data_args)
|
||||
model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
|
||||
data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length if \
|
||||
|
@ -44,7 +44,7 @@ def main():
|
|||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqTrainerForLLaMA(
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
model=model,
|
||||
args=training_args,
|
||||
|
|
|
@ -5,13 +5,13 @@ from .common import (
|
|||
preprocess_data
|
||||
)
|
||||
|
||||
from .data_collator import DataCollatorForLLaMA
|
||||
from .data_collator import DynamicDataCollatorWithPadding
|
||||
|
||||
from .peft_trainer import PeftTrainer, LogCallback
|
||||
|
||||
from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
|
||||
from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
|
||||
from .ppo import PPOTrainerForLLaMA
|
||||
from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
|
||||
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
|
||||
from .ppo import PPOPeftTrainer
|
||||
|
||||
from .config import ModelArguments
|
||||
from .other import auto_configure_device_map, get_logits_processor, plot_loss
|
||||
from .other import get_logits_processor, plot_loss
|
||||
|
|
|
@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple
|
|||
|
||||
import transformers
|
||||
from transformers import (
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Seq2SeqTrainingArguments
|
||||
)
|
||||
|
@ -151,7 +151,7 @@ def load_pretrained(
|
|||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with LoRA method."
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
padding_side="left"
|
||||
|
@ -173,13 +173,13 @@ def load_pretrained(
|
|||
config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
|
||||
logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype=torch.float16, # the llama weights are float16 type
|
||||
torch_dtype=torch.float16, # the model weights are float16 type
|
||||
**config_kwargs
|
||||
)
|
||||
model = prepare_model_for_training(model) if is_trainable else model
|
||||
|
@ -245,7 +245,7 @@ def prepare_args(
|
|||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
|
|
|
@ -12,6 +12,12 @@ class DatasetAttr:
|
|||
file_name: Optional[str] = None
|
||||
file_sha1: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.dataset_name is not None:
|
||||
return self.dataset_name
|
||||
else:
|
||||
return self.file_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
self.query_column = "input"
|
||||
|
@ -161,9 +167,11 @@ class FinetuningArguments:
|
|||
default=3,
|
||||
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM choices: [\"mlp\", \"self_attention\"]"}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
|
@ -171,7 +179,7 @@ class FinetuningArguments:
|
|||
)
|
||||
lora_alpha: Optional[float] = field(
|
||||
default=32.0,
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
|
||||
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
|
||||
)
|
||||
lora_dropout: Optional[float] = field(
|
||||
default=0.1,
|
||||
|
@ -179,7 +187,9 @@ class FinetuningArguments:
|
|||
)
|
||||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"mlp\"], \
|
||||
BLOOM choices: [\"query_key_value\", \"dense\", \"mlp\"]"}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -191,11 +201,7 @@ class FinetuningArguments:
|
|||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
|
||||
if self.name_module_trainable == "mlp":
|
||||
self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
|
||||
elif self.name_module_trainable == "qkv":
|
||||
self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
|
||||
for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
|
||||
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
|
||||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
|
||||
|
|
|
@ -9,9 +9,9 @@ from transformers.tokenization_utils import PreTrainedTokenizer
|
|||
from .other import IGNORE_INDEX
|
||||
|
||||
|
||||
class DataCollatorForLLaMA(DataCollatorWithPadding):
|
||||
class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for LLaMA. It is capable of dynamically padding for batched data.
|
||||
Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -75,7 +75,7 @@ def prepare_model_for_training(
|
|||
model: PreTrainedModel,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting
|
||||
layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
|
||||
) -> PreTrainedModel:
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
|
@ -143,29 +143,6 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
|
|||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
|
||||
|
||||
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
|
||||
r"""
|
||||
Configures device map for LLaMA.
|
||||
|
||||
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
|
||||
"""
|
||||
num_layers = 28
|
||||
layers_per_gpu = 30 / num_gpus
|
||||
device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
|
||||
added_layers = 2
|
||||
target_gpu = 0
|
||||
|
||||
for i in range(num_layers):
|
||||
if added_layers >= layers_per_gpu:
|
||||
target_gpu += 1
|
||||
added_layers = 0
|
||||
assert target_gpu < num_gpus
|
||||
device_map[f"model.layers.{i}"] = target_gpu
|
||||
added_layers += 1
|
||||
|
||||
return device_map
|
||||
|
||||
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
|
||||
"""
|
||||
EMA implementation according to TensorBoard.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
from typing import Dict, Sequence, Union
|
||||
|
||||
from .data_collator import DataCollatorForLLaMA
|
||||
from .data_collator import DynamicDataCollatorWithPadding
|
||||
|
||||
from .peft_trainer import PeftTrainer
|
||||
|
||||
|
@ -10,7 +10,7 @@ from .other import get_logger
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
|
||||
class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
@ -26,7 +26,7 @@ class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
|
|||
return super().__call__(features)
|
||||
|
||||
|
||||
class PairwiseTrainerForLLaMA(PeftTrainer):
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
"""
|
||||
|
|
|
@ -58,7 +58,7 @@ def cast_layernorm_dtype(
|
|||
return model, layer_norm_state_dict
|
||||
|
||||
|
||||
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
|||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
|
||||
# Get response from LLaMA
|
||||
# Get response from model
|
||||
query_tensors: torch.Tensor = batch["input_ids"]
|
||||
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ logger = get_logger(__name__)
|
|||
@dataclass
|
||||
class ComputeMetrics:
|
||||
r"""
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
|
||||
Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
|
||||
"""
|
||||
|
@ -62,7 +62,7 @@ class ComputeMetrics:
|
|||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
|
||||
|
||||
class Seq2SeqTrainerForLLaMA(PeftTrainer):
|
||||
class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# coding=utf-8
|
||||
# Implements user interface in browser for LLaMA fine-tuned with PEFT.
|
||||
# Implements user interface in browser for fine-tuned models.
|
||||
# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch
|
|||
import mdtex2html
|
||||
import gradio as gr
|
||||
|
||||
from utils import ModelArguments, auto_configure_device_map, load_pretrained
|
||||
from utils import ModelArguments, load_pretrained
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
|
@ -17,8 +17,8 @@ parser = HfArgumentParser(ModelArguments)
|
|||
model_args, = parser.parse_args_into_dataclasses()
|
||||
model, tokenizer = load_pretrained(model_args)
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
device_map = auto_configure_device_map(torch.cuda.device_count())
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
device_map = infer_auto_device_map(model)
|
||||
model = dispatch_model(model, device_map)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
@ -111,7 +111,7 @@ def reset_state():
|
|||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")
|
||||
gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
|
||||
|
||||
chatbot = gr.Chatbot()
|
||||
with gr.Row():
|
||||
|
|
Loading…
Reference in New Issue