support loading lora from hub

parent 0cee6ad67f
commit 0574b590ef
README.md (27 changed lines)
@@ -9,6 +9,8 @@
 ## Changelog
 
+[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
+
 [23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
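For context on the `--quantization_bit 4/8` and baichuan-7B entries above, here is a minimal sketch of what 4-bit (QLoRA-style) loading roughly corresponds to in plain `transformers`/`bitsandbytes`, using the Hub id from the new changelog entry. The exact settings this repository passes may differ; treat it as an illustration, not the project's loader.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Rough equivalent of `--quantization_bit 4`: NF4 quantization with bf16 compute (QLoRA-style defaults).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_id = "baichuan-inc/baichuan-7B"  # from the changelog entry above
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    trust_remote_code=True,  # baichuan ships custom modeling code
    device_map="auto",
)
```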
@@ -111,7 +113,7 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
     --finetuning_type lora \
@@ -132,11 +134,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -146,7 +147,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --save_steps 1000 \
     --learning_rate 5e-5 \
     --num_train_epochs 3.0 \
-    --resume_lora_training False \
     --plot_loss \
     --fp16
 ```
@@ -155,11 +155,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset comparison_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint \
     --output_dir path_to_rm_checkpoint \
     --per_device_train_batch_size 4 \
     --gradient_accumulation_steps 4 \
@@ -176,11 +175,11 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint,path_to_sft_checkpoint \
+    --checkpoint_dir path_to_sft_checkpoint \
     --reward_model path_to_rm_checkpoint \
     --output_dir path_to_ppo_checkpoint \
     --per_device_train_batch_size 2 \
@@ -205,7 +204,7 @@ accelerate launch src/train_XX.py # arguments (same as above)
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
     --checkpoint_dir path_to_checkpoint \
@@ -215,20 +214,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --predict_with_generate
 ```
 
-We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
 
 ### CLI Demo
 
 ```bash
 python src/cli_demo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
 
 ### Web Demo
 ```bash
 python src/web_demo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
 
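On the `--per_device_eval_batch_size=1` / `--max_target_length 128` recommendation above: a minimal sketch of the same memory-conscious setup in plain `transformers` (8-bit shown here), with placeholder paths as in the README commands; this is not the repository's evaluation loop.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path_to_your_model"  # placeholder, as in the commands above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# One sample at a time and a bounded generation length keep activation memory low at 4/8-bit evaluation.
inputs = tokenizer("Example prompt", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```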
@@ -236,7 +235,7 @@ python src/web_demo.py \
 
 ```bash
 python src/export_model.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export
 ```
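`export_model.py` above produces a standalone model that no longer needs the adapter checkpoint. For a LoRA checkpoint, the underlying idea is to merge the adapter into the base weights, roughly as sketched below with peft (placeholder paths; the script's actual options and implementation may differ).

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained("path_to_your_model")
model = PeftModel.from_pretrained(base, "path_to_checkpoint")
merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights

merged.save_pretrained("path_to_export")  # standalone weights, loadable without peft
AutoTokenizer.from_pretrained("path_to_your_model").save_pretrained("path_to_export")
```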
@@ -249,6 +248,8 @@ Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/ma
 
 Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
 
+Please follow the [baichuan-7B License](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) to use the baichuan-7B model.
+
 ## Citation
 
 If this work is helpful, please cite as:
@@ -29,7 +29,7 @@ from peft import (
     get_peft_model
 )
 
-from peft.utils import CONFIG_NAME
+from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from trl import AutoModelForCausalLMWithValueHead
 
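The import change above adds `WEIGHTS_NAME` next to `CONFIG_NAME`. In peft releases of this period those constants name the adapter files (typically `adapter_model.bin` and `adapter_config.json`), which is what the relaxed check in the next hunk inspects. A small sketch of that check in isolation; the helper name is hypothetical.

```python
import os
from peft.utils import CONFIG_NAME, WEIGHTS_NAME  # typically "adapter_config.json", "adapter_model.bin"

def may_be_lora_checkpoint(path: str) -> bool:
    """Mirror the check in the next hunk: reject only a directory that has adapter
    weights but no adapter config. A value with neither file present (for example
    a Hub repo id rather than a local path) is allowed through."""
    has_weights = os.path.exists(os.path.join(path, WEIGHTS_NAME))
    has_config = os.path.exists(os.path.join(path, CONFIG_NAME))
    return not (has_weights and not has_config)
```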
@@ -103,8 +103,10 @@ def _init_adapter(
     lastest_checkpoint = None
 
     if model_args.checkpoint_dir is not None:
-        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
-            "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
+                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+            raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
+                please specify `--finetuning_type full/freeze` instead.")
 
         if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
             checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
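The hunk above replaces the hard assertion so that a `--checkpoint_dir` value with no local adapter files (for example a LoRA adapter hosted on the Hugging Face Hub, per the commit title) is no longer rejected before `PeftModel.from_pretrained` can resolve it, while keeping the existing behaviour of merging every comma-separated checkpoint except the last. A rough sketch of that flow with peft, using placeholder paths and a hypothetical Hub repo id:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path_to_your_model")  # placeholder, as in the README

# Comma-separated --checkpoint_dir values: merge every checkpoint except the last,
# then continue training (or run inference) on the last one.
checkpoint_dirs = ["path_to_pt_checkpoint", "someuser/sft-lora-adapter"]  # local dir or Hub repo id
to_merge, latest = checkpoint_dirs[:-1], checkpoint_dirs[-1]

for ckpt in to_merge:
    base = PeftModel.from_pretrained(base, ckpt)  # accepts a local path or a Hub repo id
    base = base.merge_and_unload()                # fold the adapter into the base weights

model = PeftModel.from_pretrained(base, latest, is_trainable=True)
```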
@@ -170,8 +172,7 @@ def load_pretrained(
         **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    if tokenizer.pad_token_id == 64000:
-        tokenizer.pad_token_id = 0 # for baichuan model (need fix)
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
|
@ -212,7 +213,7 @@ def load_pretrained(
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
**config_kwargs
|
**config_kwargs
|
||||||
)
|
)
|
||||||
model = prepare_model_for_training(model) if is_trainable else model
|
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||||
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||||
|
|
||||||
if stage == "rm" or stage == "ppo": # add value head
|
if stage == "rm" or stage == "ppo": # add value head
|
||||||
|
|
|
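The hunk above ends where the `rm`/`ppo` stages attach a value head via trl's `AutoModelForCausalLMWithValueHead`, imported earlier in this diff. A standalone illustration of what that wrapper adds, independent of the repository's loading code (the small base model id is only an example):

```python
from trl import AutoModelForCausalLMWithValueHead

# Wraps a causal LM and adds a scalar value head on top of the hidden states,
# which the reward-model and PPO stages train and query.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")  # any causal LM id works
print(model.v_head)  # the extra value head module
```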
@@ -195,7 +195,8 @@ class FinetuningArguments:
         default="mlp",
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM choices: [\"mlp\", \"self_attention\"]"}
+                  BLOOM choices: [\"mlp\", \"self_attention\"], \
+                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -212,8 +213,9 @@ class FinetuningArguments:
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
     )
 
     def __post_init__(self):
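The updated `lora_target` help text lists `W_pack`, Baichuan's fused QKV projection. As a sketch of how such a target list maps onto a peft `LoraConfig` (the alpha/dropout values here are illustrative, not the project's defaults):

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # matches the lora_rank default above
    lora_alpha=32,              # illustrative value
    lora_dropout=0.1,           # illustrative value
    target_modules=["W_pack"],  # Baichuan; LLaMA-style models would use q_proj, v_proj, etc.
)

base = AutoModelForCausalLM.from_pretrained("baichuan-inc/baichuan-7B", trust_remote_code=True)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()
```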
@@ -73,6 +73,7 @@ def get_logits_processor() -> LogitsProcessorList:
 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 def prepare_model_for_training(
     model: PreTrainedModel,
+    finetuning_type: str,
     output_embedding_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
@@ -93,13 +94,13 @@ def prepare_model_for_training(
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
-    if hasattr(model, output_embedding_layer_name):
-        output_embedding_layer = getattr(model, output_embedding_layer_name)
+    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
+        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
         input_dtype = output_embedding_layer.weight.dtype
 
         class CastOutputToFloat(torch.nn.Sequential):
 
-            def forward(self, x):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)
 
         setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
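The `CastOutputToFloat` wrapper above returns logits in float32 for numerical stability while the rest of the model can stay in half precision, and it is now applied only when not doing full fine-tuning (the new `finetuning_type != "full"` guard). A self-contained sketch of the same casting pattern on a dummy layer, plain PyTorch only (bfloat16 stands in for the half-precision weights):

```python
import torch

lm_head = torch.nn.Linear(16, 32).to(torch.bfloat16)  # stand-in for a half-precision output embedding
input_dtype = lm_head.weight.dtype

class CastOutputToFloat(torch.nn.Sequential):
    # Cast the input to the wrapped layer's dtype, return the logits in float32.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x.to(input_dtype)).to(torch.float32)

wrapped = CastOutputToFloat(lm_head)
logits = wrapped(torch.randn(2, 16))  # float32 input is cast to bfloat16 inside the layer
print(logits.dtype)                   # torch.float32
```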