remove PeftTrainer

This commit is contained in:
hiyouga 2023-09-10 22:23:23 +08:00
parent baac22f4f4
commit b218c271ed
17 changed files with 75 additions and 259 deletions

View File

@ -1,5 +1,5 @@
torch>=1.13.1
transformers>=4.29.1
transformers>=4.30.0
datasets>=2.12.0
accelerate>=0.21.0
peft==0.4.0

View File

@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.training_args import TrainingArguments
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
@ -17,6 +19,24 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class SavePeftModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
return control
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of training.
"""
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
return control
class LogCallback(TrainerCallback):
def __init__(self, runner=None):

View File

@ -2,10 +2,6 @@ IGNORE_INDEX = -100
LOG_FILE_NAME = "trainer_log.jsonl"
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
METHODS = ["full", "freeze", "lora"]

View File

@ -192,6 +192,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
else:
assert False
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
@ -204,26 +205,7 @@ class LlamaMLP(nn.Module):
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -301,27 +283,9 @@ class LlamaAttention(nn.Module):
else:
past_len = 0
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
q = torch.cat(q, dim=-1)
k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
k = torch.cat(k, dim=-1)
v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
v = torch.cat(v, dim=-1)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
@ -377,12 +341,7 @@ class LlamaAttention(nn.Module):
attn_output = attn_output.reshape(bsz, q_len, h_size)
attn_weights = attn_outputs[2] if output_attentions else None
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
@ -703,12 +662,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None

View File

@ -1,49 +1,21 @@
import os
import torch
from typing import Dict
from transformers.trainer import WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if os.path.exists(weights_file):
model_state_dict = torch.load(weights_file, map_location="cpu")
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
else:
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
return False
return True
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
if not os.path.exists(valuehead_file):
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(vhead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@ -11,7 +11,6 @@ from peft import (
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
from llmtuner.tuner.core.utils import find_all_linear_modules
if TYPE_CHECKING:
@ -53,9 +52,6 @@ def init_adapter(
else:
param.data = param.data.to(torch.float32)
if model_args.checkpoint_dir is not None:
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None

View File

@ -38,7 +38,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
check_min_version("4.29.1")
check_min_version("4.30.0")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
@ -78,7 +78,7 @@ def load_model_and_tokenizer(
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
model_to_load = model_args.checkpoint_dir[0]
else:
model_to_load = model_args.model_name_or_path
@ -197,6 +197,7 @@ def load_model_and_tokenizer(
# Prepare model with valuehead for RLHF
if stage == "rm" or stage == "ppo":
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model._keys_to_ignore_on_save = None
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")

View File

@ -1,118 +0,0 @@
import os
import torch
from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__)
class PeftModelMixin:
r"""
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
"""
def __init__(self) -> None: # for type checking
self.model: PreTrainedModel = None
self.tokenizer: "PreTrainedTokenizer" = None
self.args: "Seq2SeqTrainingArguments" = None
self.finetuning_args: "FinetuningArguments" = None
self.state: "TrainerState" = None
raise AssertionError("Mixin should not be initialized.")
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
r"""
Saves trainable parameters as model checkpoint.
This function will only be executed at the process zero.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
model = self.model
model_unwrapped = unwrap_model(model)
if isinstance(model_unwrapped, PreTrainedModelWrapper):
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200
model_state_dict = state_dict or model.state_dict()
v_head_state_dict = {
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
for name in model_state_dict.keys() if name.startswith("v_head.")
}
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
model = model_unwrapped.pretrained_model
model_unwrapped = unwrap_model(model)
state_dict = state_dict or get_state_dict(model)
if not isinstance(model, (PeftModel, PreTrainedModel)):
if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)):
model_unwrapped.config.use_cache = True
model_unwrapped.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
model_unwrapped.config.use_cache = False
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
model.config.use_cache = True
model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
model.config.use_cache = False
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
try:
self.tokenizer.save_pretrained(output_dir)
except:
logger.warning("Cannot save tokenizer, copy the files manually.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):
r"""
Loads trainable parameters from model checkpoint.
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
"""
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model)
if isinstance(model, PreTrainedModelWrapper):
model.v_head.load_state_dict(torch.load(
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
))
model = model.pretrained_model
if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning
load_trainable_params(model, self.state.best_model_checkpoint)
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
r"""
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
"""
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
Seq2SeqTrainer.__init__(self, **kwargs)
self.finetuning_args = finetuning_args

View File

@ -6,18 +6,16 @@ from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
if TYPE_CHECKING:
from transformers import PreTrainedModel
from llmtuner.hparams import FinetuningArguments
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
beta: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = finetuning_args.dpo_beta
self.beta = beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, model=model, **kwargs)

View File

@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -37,10 +37,10 @@ def run_dpo(
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
model=model,
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,

View File

@ -4,27 +4,25 @@ import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from transformers import GenerationConfig, TrainerState, TrainerControl
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from trl import PPOTrainer
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
from llmtuner.hparams import GeneratingArguments
logger = get_logger(__name__)
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
@ -32,9 +30,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
def __init__(
self,
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["LogCallback"],
callbacks: List["TrainerCallback"],
compute_dtype: torch.dtype,
**kwargs
):
@ -43,9 +40,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.log_callback = callbacks[0]
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
self.compute_dtype = compute_dtype
self.state = TrainerState()
self.control = TrainerControl()
@ -147,7 +143,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
self.log_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@torch.no_grad()
def get_inputs(
@ -296,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"""
if self.args.should_save:
self._save(output_dir)
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)

View File

@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -61,11 +62,10 @@ def run_ppo(
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
ppo_trainer = CustomPPOTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks,
callbacks=callbacks + [SavePeftModelCallback()],
compute_dtype=model_args.compute_dtype,
config=ppo_config,
model=model,

View File

@ -2,12 +2,11 @@
import math
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForLanguageModeling
from transformers import DataCollatorForLanguageModeling, Trainer
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@ -27,8 +26,7 @@ def run_pt(
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
trainer = PeftTrainer(
finetuning_args=finetuning_args,
trainer = Trainer(
model=model,
args=training_args,
tokenizer=tokenizer,

View File

@ -2,9 +2,9 @@ import os
import json
import torch
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers import Trainer
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
@ -14,7 +14,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class PairwisePeftTrainer(PeftTrainer):
class PairwiseTrainer(Trainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
"""

View File

@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
from llmtuner.tuner.rm.trainer import PairwiseTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -33,13 +34,12 @@ def run_rm(
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = PairwisePeftTrainer(
finetuning_args=finetuning_args,
trainer = PairwiseTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
callbacks=callbacks + [SavePeftModelCallback()],
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args)
)

View File

@ -4,10 +4,10 @@ import torch
import numpy as np
import torch.nn as nn
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers import Seq2SeqTrainer
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
@ -16,7 +16,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class Seq2SeqPeftTrainer(PeftTrainer):
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
r"""
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
"""

View File

@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
@ -45,8 +45,7 @@ def run_sft(
training_args = Seq2SeqTrainingArguments(**training_args_dict)
# Initialize our Trainer
trainer = Seq2SeqPeftTrainer(
finetuning_args=finetuning_args,
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,