From d4be857e23c74ed65e06903e19da6f18f15d9e30 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 12 Sep 2023 16:10:10 +0800 Subject: [PATCH] fix #762 #814 --- README.md | 2 +- README_zh.md | 2 +- requirements.txt | 2 +- src/llmtuner/extras/callbacks.py | 10 ++++++---- .../extras/{models => patches}/__init__.py | 0 .../extras/{models => patches}/flash_llama.py | 11 +++++----- src/llmtuner/tuner/core/adapter.py | 2 ++ src/llmtuner/tuner/core/loader.py | 20 +++++++++---------- 8 files changed, 26 insertions(+), 23 deletions(-) rename src/llmtuner/extras/{models => patches}/__init__.py (100%) rename src/llmtuner/extras/{models => patches}/flash_llama.py (97%) diff --git a/README.md b/README.md index ebb9b779..a3c2d81d 100644 --- a/README.md +++ b/README.md @@ -470,7 +470,7 @@ If this work is helpful, please kindly cite as: ## Acknowledgement -This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works. +This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora), [FastChat](https://github.com/lm-sys/FastChat) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works. ## Star History diff --git a/README_zh.md b/README_zh.md index 243823e1..f603780f 100644 --- a/README_zh.md +++ b/README_zh.md @@ -469,7 +469,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ ## 致谢 -本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。 +本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora)、[FastChat](https://github.com/lm-sys/FastChat) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。 ## Star History diff --git a/requirements.txt b/requirements.txt index 92ef9d76..1d36fd33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ torch>=1.13.1 transformers>=4.30.0 datasets>=2.12.0 accelerate>=0.21.0 -peft==0.4.0 +peft>=0.4.0 trl>=0.7.1 scipy sentencepiece diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py index a3ff8fee..beb13bfa 100644 --- a/src/llmtuner/extras/callbacks.py +++ b/src/llmtuner/extras/callbacks.py @@ -5,9 +5,7 @@ from typing import TYPE_CHECKING from datetime import timedelta from transformers import TrainerCallback -from transformers.trainer_callback import TrainerControl, TrainerState from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR -from transformers.training_args import TrainingArguments from llmtuner.extras.constants import LOG_FILE_NAME from llmtuner.extras.logging import get_logger @@ -27,14 +25,18 @@ class SavePeftModelCallback(TrainerCallback): """ if args.should_save: output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) - getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir) + model = kwargs.pop("model") + if getattr(model, "is_peft_model", False): + getattr(model, "pretrained_model").save_pretrained(output_dir) def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): r""" Event called at the end of training. 
""" if args.should_save: - getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir) + model = kwargs.pop("model") + if getattr(model, "is_peft_model", False): + getattr(model, "pretrained_model").save_pretrained(args.output_dir) class LogCallback(TrainerCallback): diff --git a/src/llmtuner/extras/models/__init__.py b/src/llmtuner/extras/patches/__init__.py similarity index 100% rename from src/llmtuner/extras/models/__init__.py rename to src/llmtuner/extras/patches/__init__.py diff --git a/src/llmtuner/extras/models/flash_llama.py b/src/llmtuner/extras/patches/flash_llama.py similarity index 97% rename from src/llmtuner/extras/models/flash_llama.py rename to src/llmtuner/extras/patches/flash_llama.py index d6c078bd..1d6ee66d 100644 --- a/src/llmtuner/extras/models/flash_llama.py +++ b/src/llmtuner/extras/patches/flash_llama.py @@ -230,7 +230,8 @@ class LlamaAttention(torch.nn.Module): new_len = past_len+q.size(1) if new_len > past_kv.size(1): past_kv = torch.cat( - [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1 + [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], + dim=1 ) past_kv[:, past_len:new_len] = kv kv = past_kv[:, :new_len] @@ -248,20 +249,18 @@ class LlamaAttention(torch.nn.Module): attn_outputs = flash_attn_varlen_kvpacked_func( unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p=0.0, softmax_scale=1.0/self.norm_factor, + dropout_p=0.0, softmax_scale=1.0 / self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions ) attn_output = attn_outputs[0] if output_attentions else attn_outputs - attn_output = pad_input( - attn_output, indices_q, bsz, q_len - ).reshape(bsz, q_len, h_size) + attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size) attn_weights = attn_outputs[2] if output_attentions else None else: # no padding tokens, more efficient attn_outputs = flash_attn_kvpacked_func( - q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor, + q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor, causal=(not has_layer_past), return_attn_probs=output_attentions ) attn_output = attn_outputs[0] if output_attentions else attn_outputs diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py index 64d1f485..6a9e454e 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/tuner/core/adapter.py @@ -92,6 +92,8 @@ def init_adapter( target_modules=target_modules ) model = get_peft_model(model, lora_config) + if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 + model.base_model.peft_config = model.peft_config if model_args.checkpoint_dir is not None: logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 656d3918..95c1eee9 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -4,7 +4,6 @@ import torch from types import MethodType from typing import TYPE_CHECKING, Literal, Optional, Tuple -import transformers from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -41,7 +40,7 @@ logger = get_logger(__name__) check_min_version("4.30.0") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") 
-require_version("peft==0.4.0", "To fix: pip install peft==0.4.0") +require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1") @@ -133,11 +132,11 @@ def load_model_and_tokenizer( # Set flash attention if model_args.flash_attn and getattr(config, "model_type", None) == "llama": import transformers.models.llama.modeling_llama as LlamaModule - from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask - LlamaModule.LlamaRMSNorm = LlamaRMSNorm - LlamaModule.LlamaAttention = LlamaAttention - LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask - if not hasattr(config, "num_key_value_heads"): + import llmtuner.extras.patches.flash_llama as FlashLlama + LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm + LlamaModule.LlamaAttention = FlashLlama.LlamaAttention + LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask + if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads")) if getattr(config, "pretraining_tp", 1) != 1: setattr(config, "pretraining_tp", 1) @@ -199,11 +198,11 @@ def load_model_and_tokenizer( # Prepare model with valuehead for RLHF if stage == "rm" or stage == "ppo": - model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model) + model = AutoModelForCausalLMWithValueHead.from_pretrained(model) model._keys_to_ignore_on_save = None reset_logging() if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model - logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.") + logger.warning("Only the last checkpoint containing valuehead will be loaded.") if load_valuehead_params(model, model_args.checkpoint_dir[-1]): model.v_head.load_state_dict({ "summary.weight": getattr(model, "reward_head_weight"), @@ -212,7 +211,8 @@ def load_model_and_tokenizer( if stage == "ppo": # load reward model logger.info("Load reward model from {}".format(model_args.reward_model)) - model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) + if getattr(model, "is_peft_model", False): + model.pretrained_model.load_adapter(model_args.reward_model, "reward") assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." # Prepare model for inference