This commit is contained in:
hiyouga 2023-09-12 16:10:10 +08:00
parent 3b306478d4
commit d4be857e23
8 changed files with 26 additions and 23 deletions

View File

@ -470,7 +470,7 @@ If this work is helpful, please kindly cite as:
## Acknowledgement
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works.
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora), [FastChat](https://github.com/lm-sys/FastChat) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works.
## Star History

View File

@ -469,7 +469,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
## 致谢
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora)、[FastChat](https://github.com/lm-sys/FastChat) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。
## Star History

View File

@ -2,7 +2,7 @@ torch>=1.13.1
transformers>=4.30.0
datasets>=2.12.0
accelerate>=0.21.0
peft==0.4.0
peft>=0.4.0
trl>=0.7.1
scipy
sentencepiece

View File

@ -5,9 +5,7 @@ from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.training_args import TrainingArguments
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
@ -27,14 +25,18 @@ class SavePeftModelCallback(TrainerCallback):
"""
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
model = kwargs.pop("model")
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(output_dir)
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
if args.should_save:
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
model = kwargs.pop("model")
if getattr(model, "is_peft_model", False):
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
class LogCallback(TrainerCallback):

View File

@ -230,7 +230,8 @@ class LlamaAttention(torch.nn.Module):
new_len = past_len+q.size(1)
if new_len > past_kv.size(1):
past_kv = torch.cat(
[past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
[past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
dim=1
)
past_kv[:, past_len:new_len] = kv
kv = past_kv[:, :new_len]
@ -248,20 +249,18 @@ class LlamaAttention(torch.nn.Module):
attn_outputs = flash_attn_varlen_kvpacked_func(
unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs
attn_output = pad_input(
attn_output, indices_q, bsz, q_len
).reshape(bsz, q_len, h_size)
attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
else:
# no padding tokens, more efficient
attn_outputs = flash_attn_kvpacked_func(
q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
causal=(not has_layer_past), return_attn_probs=output_attentions
)
attn_output = attn_outputs[0] if output_attentions else attn_outputs

View File

@ -92,6 +92,8 @@ def init_adapter(
target_modules=target_modules
)
model = get_peft_model(model, lora_config)
if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
model.base_model.peft_config = model.peft_config
if model_args.checkpoint_dir is not None:
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

View File

@ -4,7 +4,6 @@ import torch
from types import MethodType
from typing import TYPE_CHECKING, Literal, Optional, Tuple
import transformers
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@ -41,7 +40,7 @@ logger = get_logger(__name__)
check_min_version("4.30.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
@ -133,11 +132,11 @@ def load_model_and_tokenizer(
# Set flash attention
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
import transformers.models.llama.modeling_llama as LlamaModule
from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
LlamaModule.LlamaRMSNorm = LlamaRMSNorm
LlamaModule.LlamaAttention = LlamaAttention
LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
if not hasattr(config, "num_key_value_heads"):
import llmtuner.extras.patches.flash_llama as FlashLlama
LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
if getattr(config, "pretraining_tp", 1) != 1:
setattr(config, "pretraining_tp", 1)
@ -199,11 +198,11 @@ def load_model_and_tokenizer(
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model._keys_to_ignore_on_save = None
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),
@ -212,7 +211,8 @@ def load_model_and_tokenizer(
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
if getattr(model, "is_peft_model", False):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
# Prepare model for inference