diff --git a/README.md b/README.md
index 0b7dd9a2..1db23423 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,8 @@
 
 ## Changelog
 
+[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
+
 [23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
@@ -111,7 +113,7 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset wiki_demo \
     --finetuning_type lora \
@@ -132,11 +134,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -146,7 +147,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --save_steps 1000 \
     --learning_rate 5e-5 \
     --num_train_epochs 3.0 \
-    --resume_lora_training False \
     --plot_loss \
     --fp16
 ```
@@ -155,11 +155,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset comparison_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint \
     --output_dir path_to_rm_checkpoint \
     --per_device_train_batch_size 4 \
     --gradient_accumulation_steps 4 \
@@ -176,11 +175,11 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_train \
     --dataset alpaca_gpt4_en \
     --finetuning_type lora \
-    --checkpoint_dir path_to_pt_checkpoint,path_to_sft_checkpoint \
+    --checkpoint_dir path_to_sft_checkpoint \
     --reward_model path_to_rm_checkpoint \
     --output_dir path_to_ppo_checkpoint \
     --per_device_train_batch_size 2 \
@@ -205,7 +204,7 @@ accelerate launch src/train_XX.py # arguments (same as above)
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
     --checkpoint_dir path_to_checkpoint \
@@ -215,20 +214,20 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --predict_with_generate
 ```
 
-We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in 4/8-bit evaluation.
 
 ### CLI Demo
 
 ```bash
 python src/cli_demo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
 
 ### Web Demo
 
 ```bash
 python src/web_demo.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
@@ -236,7 +235,7 @@ python src/web_demo.py \
 
 ```bash
 python src/export_model.py \
-    --model_name_or_path path_to_llama_model \
+    --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint \
     --output_dir path_to_export
 ```
@@ -249,6 +248,8 @@ Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/ma
 
 Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
 
+Please follow the [baichuan-7B License](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) to use the baichuan-7B model.
+
 ## Citation
 
 If this work is helpful, please cite as:
diff --git a/src/utils/common.py b/src/utils/common.py
index 094516ea..5d85396b 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -29,7 +29,7 @@ from peft import (
     get_peft_model
 )
 
-from peft.utils import CONFIG_NAME
+from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -103,8 +103,10 @@ def _init_adapter(
         lastest_checkpoint = None
 
         if model_args.checkpoint_dir is not None:
-            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
-                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+            if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
+                    not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+                raise ValueError("The given checkpoint may not be a LoRA checkpoint, "
+                                 "please specify `--finetuning_type full/freeze` instead.")
 
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
@@ -170,8 +172,7 @@ def load_pretrained(
         **config_kwargs
    )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    if tokenizer.pad_token_id == 64000:
-        tokenizer.pad_token_id = 0 # for baichuan model (need fix)
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
@@ -212,7 +213,7 @@ def load_pretrained(
         low_cpu_mem_usage=True,
         **config_kwargs
     )
-    model = prepare_model_for_training(model) if is_trainable else model
+    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
     if stage == "rm" or stage == "ppo": # add value head
diff --git a/src/utils/config.py b/src/utils/config.py
index 5b919ec6..0778cb7c 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -195,7 +195,8 @@ class FinetuningArguments:
         default="mlp",
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM choices: [\"mlp\", \"self_attention\"]"}
+                  BLOOM choices: [\"mlp\", \"self_attention\"], \
+                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -212,8 +213,9 @@
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
-                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
     )
 
     def __post_init__(self):
diff --git a/src/utils/other.py b/src/utils/other.py
index 88bf081b..25603c70 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -73,6 +73,7 @@ def get_logits_processor() -> LogitsProcessorList:
 # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
 def prepare_model_for_training(
     model: PreTrainedModel,
+    finetuning_type: str,
     output_embedding_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
@@ -93,13 +94,13 @@
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
-    if hasattr(model, output_embedding_layer_name):
-        output_embedding_layer = getattr(model, output_embedding_layer_name)
+    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
+        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
         input_dtype = output_embedding_layer.weight.dtype
 
         class CastOutputToFloat(torch.nn.Sequential):
 
-            def forward(self, x):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return super().forward(x.to(input_dtype)).to(torch.float32)
 
         setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
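
Below is a quick usage sketch (not part of the diff above) showing how the baichuan-7B support introduced here might be exercised end to end. It reuses only arguments that already appear in the README hunks, plus a `--lora_target W_pack` flag assumed to be exposed through the new `lora_target` choices in `FinetuningArguments`; the dataset and output path are placeholders.

```bash
# Hypothetical sketch: LoRA fine-tuning of baichuan-7B using the W_pack target
# module added to FinetuningArguments in this change. All other flags mirror the
# SFT example from the README; --lora_target is assumed to map to FinetuningArguments.lora_target.
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
    --model_name_or_path baichuan-inc/baichuan-7B \
    --do_train \
    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --lora_target W_pack \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16
```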