forked from p04798526/LLaMA-Factory-Mirror
parent 3b306478d4
commit d4be857e23
@@ -470,7 +470,7 @@ If this work is helpful, please kindly cite as:
 
 ## Acknowledgement
 
-This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works.
+This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora), [FastChat](https://github.com/lm-sys/FastChat) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works.
 
 ## Star History
 
@@ -469,7 +469,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 
 ## 致谢
 
-本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。
+本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora)、[FastChat](https://github.com/lm-sys/FastChat) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。
 
 ## Star History
 
@@ -2,7 +2,7 @@ torch>=1.13.1
 transformers>=4.30.0
 datasets>=2.12.0
 accelerate>=0.21.0
-peft==0.4.0
+peft>=0.4.0
 trl>=0.7.1
 scipy
 sentencepiece
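Note: the requirements change above loosens the PEFT pin from `peft==0.4.0` to `peft>=0.4.0`. As a quick illustration (not part of the repo), the sketch below checks at runtime that the installed peft satisfies the relaxed constraint, using only `importlib.metadata` and `packaging`:

```python
# Minimal sketch (assumes peft and packaging are installed); mirrors the
# relaxed "peft>=0.4.0" requirement rather than the old exact pin.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("peft"))
if installed < Version("0.4.0"):
    raise RuntimeError(f"peft {installed} is too old; run: pip install 'peft>=0.4.0'")
print(f"peft {installed} satisfies peft>=0.4.0")
```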
@@ -5,9 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta
 
 from transformers import TrainerCallback
-from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
-from transformers.training_args import TrainingArguments
 
 from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
@@ -27,14 +25,18 @@ class SavePeftModelCallback(TrainerCallback):
         """
         if args.should_save:
             output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
+            model = kwargs.pop("model")
+            if getattr(model, "is_peft_model", False):
+                getattr(model, "pretrained_model").save_pretrained(output_dir)
 
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
+            model = kwargs.pop("model")
+            if getattr(model, "is_peft_model", False):
+                getattr(model, "pretrained_model").save_pretrained(args.output_dir)
 
 
 class LogCallback(TrainerCallback):
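Note: the callback now pops the model out of `kwargs` and only saves the underlying adapter when the wrapper reports `is_peft_model`. A minimal, self-contained sketch of that guard is below; the `Dummy*` classes are hypothetical stand-ins, not LLaMA-Factory or trl classes:

```python
# Hypothetical stand-ins used only to demonstrate the guard pattern.
class DummyAdapter:
    def save_pretrained(self, output_dir: str) -> None:
        print(f"adapter weights would be written to {output_dir}")

class DummyValueHeadModel:
    is_peft_model = True                 # trl's value-head wrapper exposes a similar flag
    pretrained_model = DummyAdapter()

def on_save_like(kwargs: dict, output_dir: str) -> None:
    model = kwargs.pop("model")          # consume the model reference once
    if getattr(model, "is_peft_model", False):   # skip saving for full (non-PEFT) models
        getattr(model, "pretrained_model").save_pretrained(output_dir)

on_save_like({"model": DummyValueHeadModel()}, "checkpoint-100")
```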
@@ -230,7 +230,8 @@ class LlamaAttention(torch.nn.Module):
             new_len = past_len+q.size(1)
             if new_len > past_kv.size(1):
                 past_kv = torch.cat(
-                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
+                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)],
+                    dim=1
                 )
             past_kv[:, past_len:new_len] = kv
             kv = past_kv[:, :new_len]
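Note: the only change here is passing `dim=1` by keyword for readability; the surrounding logic grows a preallocated key/value cache in fixed chunks of 256 positions along the sequence axis. A standalone sketch of that growth pattern with toy shapes (assuming torch is installed):

```python
import torch

# Toy shapes, not the module's real tensors: (batch, seq, 2, kv_heads, head_dim)
bsz, past_len, kv_heads, head_dim = 2, 300, 4, 8
past_kv = torch.zeros(bsz, 320, 2, kv_heads, head_dim)   # preallocated cache
kv = torch.randn(bsz, 64, 2, kv_heads, head_dim)         # new key/value states

new_len = past_len + kv.size(1)
if new_len > past_kv.size(1):
    # grow the cache by a fixed 256-position chunk along the sequence dimension
    past_kv = torch.cat(
        [past_kv, torch.empty(bsz, 256, 2, kv_heads, head_dim, dtype=kv.dtype, device=kv.device)],
        dim=1
    )
past_kv[:, past_len:new_len] = kv     # write the new states into the cache
kv = past_kv[:, :new_len]             # attend over everything cached so far
print(past_kv.shape, kv.shape)        # torch.Size([2, 576, 2, 4, 8]) torch.Size([2, 364, 2, 4, 8])
```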
@@ -253,9 +254,7 @@ class LlamaAttention(torch.nn.Module):
             )
 
             attn_output = attn_outputs[0] if output_attentions else attn_outputs
-            attn_output = pad_input(
-                attn_output, indices_q, bsz, q_len
-            ).reshape(bsz, q_len, h_size)
+            attn_output = pad_input(attn_output, indices_q, bsz, q_len).reshape(bsz, q_len, h_size)
             attn_weights = attn_outputs[2] if output_attentions else None
 
         else:
@@ -92,6 +92,8 @@ def init_adapter(
                 target_modules=target_modules
             )
             model = get_peft_model(model, lora_config)
+            if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
+                model.base_model.peft_config = model.peft_config
 
     if model_args.checkpoint_dir is not None:
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
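Note: the two added lines work around huggingface/peft#923, where the PEFT wrapper and its base model can end up holding two different `peft_config` objects; the fix re-aliases them so both point at the same object. A plain-Python sketch of the check (toy classes, not peft objects):

```python
# Toy classes; the real objects are peft's PeftModel and its wrapped base model.
class ToyBaseModel:
    def __init__(self, peft_config):
        self.peft_config = peft_config

class ToyPeftModel:
    def __init__(self):
        self.peft_config = {"default": "lora-config"}
        self.base_model = ToyBaseModel(dict(self.peft_config))  # accidental separate copy

model = ToyPeftModel()
if id(model.peft_config) != id(model.base_model.peft_config):
    model.base_model.peft_config = model.peft_config   # share one object again

assert model.peft_config is model.base_model.peft_config
```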
@@ -4,7 +4,6 @@ import torch
 from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple
 
-import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -41,7 +40,7 @@ logger = get_logger(__name__)
 check_min_version("4.30.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
+require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
 require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
 
 
@@ -133,11 +132,11 @@ def load_model_and_tokenizer(
     # Set flash attention
     if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
         import transformers.models.llama.modeling_llama as LlamaModule
-        from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
-        LlamaModule.LlamaRMSNorm = LlamaRMSNorm
-        LlamaModule.LlamaAttention = LlamaAttention
-        LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
-        if not hasattr(config, "num_key_value_heads"):
+        import llmtuner.extras.patches.flash_llama as FlashLlama
+        LlamaModule.LlamaRMSNorm = FlashLlama.LlamaRMSNorm
+        LlamaModule.LlamaAttention = FlashLlama.LlamaAttention
+        LlamaModule.LlamaModel._prepare_decoder_attention_mask = FlashLlama._prepare_decoder_attention_mask
+        if not hasattr(config, "num_key_value_heads"): # for LLaMA-1 models
             setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
         if getattr(config, "pretraining_tp", 1) != 1:
             setattr(config, "pretraining_tp", 1)
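Note: besides switching the patch import to `llmtuner.extras.patches.flash_llama` and assigning its classes onto `transformers.models.llama.modeling_llama`, the hunk annotates the fallback for LLaMA-1 configs, which predate grouped-query attention and therefore lack `num_key_value_heads`. A self-contained sketch of that fallback, using `SimpleNamespace` as a stand-in for the real config object:

```python
from types import SimpleNamespace

# Stand-in for a LLaMA-1-style config that has no num_key_value_heads attribute.
config = SimpleNamespace(num_attention_heads=32)

if not hasattr(config, "num_key_value_heads"):  # for LLaMA-1 models
    # Plain multi-head attention: one KV head per attention head.
    setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))

assert config.num_key_value_heads == 32
```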
@@ -199,11 +198,11 @@ def load_model_and_tokenizer(
 
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
-        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        model._keys_to_ignore_on_save = None
        reset_logging()
        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+            logger.warning("Only the last checkpoint containing valuehead will be loaded.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
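Note: the redundant type annotation is dropped and the warning message shortened; the RM branch then copies the saved value-head tensors into `model.v_head` via `load_state_dict`. A toy sketch of loading named tensors into a small head module (assuming torch is installed; the 16-dim linear head is illustrative only, not trl's real value head):

```python
import torch
import torch.nn as nn

# Illustrative value head: a single linear layer mapping hidden states to a scalar.
v_head = nn.Linear(16, 1)
loaded = {"weight": torch.zeros(1, 16), "bias": torch.zeros(1)}
v_head.load_state_dict(loaded)             # keys must match the module's parameter names
print(v_head.weight.abs().sum().item())    # 0.0 after loading the zero weights
```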
@@ -212,7 +211,8 @@ def load_model_and_tokenizer(
 
         if stage == "ppo": # load reward model
             logger.info("Load reward model from {}".format(model_args.reward_model))
-            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
+            if getattr(model, "is_peft_model", False):
+                model.pretrained_model.load_adapter(model_args.reward_model, "reward")
             assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 
     # Prepare model for inference