forked from p04798526/LLaMA-Factory-Mirror
remove PeftTrainer
This commit is contained in:
parent
baac22f4f4
commit
b218c271ed
|
@ -1,5 +1,5 @@
|
|||
torch>=1.13.1
|
||||
transformers>=4.29.1
|
||||
transformers>=4.30.0
|
||||
datasets>=2.12.0
|
||||
accelerate>=0.21.0
|
||||
peft==0.4.0
|
||||
|
|
|
@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
|
|||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import has_length
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
@ -17,6 +19,24 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
|
||||
return control
|
||||
|
||||
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
|
||||
return control
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
|
|
|
@ -2,10 +2,6 @@ IGNORE_INDEX = -100
|
|||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
|
|
@ -192,6 +192,7 @@ class FlashRotaryEmbedding(torch.nn.Module):
|
|||
else:
|
||||
assert False
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -204,26 +205,7 @@ class LlamaMLP(nn.Module):
|
|||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
if self.config.pretraining_tp > 1:
|
||||
slice = self.intermediate_size // self.config.pretraining_tp
|
||||
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
|
||||
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
|
||||
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
|
||||
|
||||
gate_proj = torch.cat(
|
||||
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
|
||||
)
|
||||
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
|
||||
|
||||
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
|
||||
down_proj = [
|
||||
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
down_proj = sum(down_proj)
|
||||
else:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
return down_proj
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
@ -301,27 +283,9 @@ class LlamaAttention(nn.Module):
|
|||
else:
|
||||
past_len = 0
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
q = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
k = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
k = torch.cat(k, dim=-1)
|
||||
|
||||
v = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
v = torch.cat(v, dim=-1)
|
||||
|
||||
else:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
@ -377,12 +341,7 @@ class LlamaAttention(nn.Module):
|
|||
attn_output = attn_output.reshape(bsz, q_len, h_size)
|
||||
attn_weights = attn_outputs[2] if output_attentions else None
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
|
||||
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
@ -703,12 +662,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
else:
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
loss = None
|
||||
|
|
|
@ -1,49 +1,21 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import Dict
|
||||
from transformers.trainer import WEIGHTS_NAME
|
||||
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
|
||||
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
state_dict: Dict[str, torch.Tensor] = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
|
||||
|
||||
return filtered_state_dict
|
||||
|
||||
|
||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_file):
|
||||
model_state_dict = torch.load(weights_file, map_location="cpu")
|
||||
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
|
||||
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
|
||||
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
|
||||
else:
|
||||
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
|
||||
if not os.path.exists(valuehead_file):
|
||||
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if not os.path.exists(vhead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]), persistent=False)
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
|
|
@ -11,7 +11,6 @@ from peft import (
|
|||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -53,9 +52,6 @@ def init_adapter(
|
|||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
latest_checkpoint = None
|
||||
|
|
|
@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.29.1")
|
||||
check_min_version("4.30.0")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
|
||||
|
@ -78,7 +78,7 @@ def load_model_and_tokenizer(
|
|||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
@ -197,6 +197,7 @@ def load_model_and_tokenizer(
|
|||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
model._keys_to_ignore_on_save = None
|
||||
reset_logging()
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
|
|
|
@ -1,118 +0,0 @@
|
|||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PeftModelMixin:
|
||||
r"""
|
||||
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None: # for type checking
|
||||
self.model: PreTrainedModel = None
|
||||
self.tokenizer: "PreTrainedTokenizer" = None
|
||||
self.args: "Seq2SeqTrainingArguments" = None
|
||||
self.finetuning_args: "FinetuningArguments" = None
|
||||
self.state: "TrainerState" = None
|
||||
raise AssertionError("Mixin should not be initialized.")
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
Saves trainable parameters as model checkpoint.
|
||||
|
||||
This function will only be executed at the process zero.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
model = self.model
|
||||
model_unwrapped = unwrap_model(model)
|
||||
|
||||
if isinstance(model_unwrapped, PreTrainedModelWrapper):
|
||||
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200
|
||||
model_state_dict = state_dict or model.state_dict()
|
||||
v_head_state_dict = {
|
||||
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
|
||||
for name in model_state_dict.keys() if name.startswith("v_head.")
|
||||
}
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
model = model_unwrapped.pretrained_model
|
||||
model_unwrapped = unwrap_model(model)
|
||||
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
if not isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)):
|
||||
model_unwrapped.config.use_cache = True
|
||||
model_unwrapped.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
model_unwrapped.config.use_cache = False
|
||||
else:
|
||||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(
|
||||
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
|
||||
try:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, copy the files manually.")
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
r"""
|
||||
Loads trainable parameters from model checkpoint.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model.v_head.load_state_dict(torch.load(
|
||||
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
|
||||
))
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
|
||||
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
Seq2SeqTrainer.__init__(self, **kwargs)
|
||||
self.finetuning_args = finetuning_args
|
|
@ -6,18 +6,16 @@ from trl import DPOTrainer
|
|||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.tuner.core.trainer import PeftModelMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
beta: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
|
@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
|||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = finetuning_args.dpo_beta
|
||||
self.beta = beta
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
|
|
|
@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
|
|||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
|
||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
@ -37,10 +37,10 @@ def run_dpo(
|
|||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = DPOPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
model=model,
|
||||
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
|
|
|
@ -4,27 +4,25 @@ import torch
|
|||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import GenerationConfig, TrainerState, TrainerControl
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
@ -32,9 +30,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
def __init__(
|
||||
self,
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
callbacks: List["TrainerCallback"],
|
||||
compute_dtype: torch.dtype,
|
||||
**kwargs
|
||||
):
|
||||
|
@ -43,9 +40,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
|
||||
|
||||
self.args = training_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
self.compute_dtype = compute_dtype
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
|
@ -147,7 +143,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
|
||||
self.log_callback.on_train_end(self.args, self.state, self.control)
|
||||
self.log_callback.on_train_end(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(
|
||||
|
@ -296,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||
"""
|
||||
if self.args.should_save:
|
||||
self._save(output_dir)
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
|
|
@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
|
|||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
@ -61,11 +62,10 @@ def run_ppo(
|
|||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
|
|
|
@ -2,12 +2,11 @@
|
|||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from transformers import DataCollatorForLanguageModeling, Trainer
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
@ -27,8 +26,7 @@ def run_pt(
|
|||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
|
|
|
@ -2,9 +2,9 @@ import os
|
|||
import json
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from transformers import Trainer
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
class PairwiseTrainer(Trainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
"""
|
||||
|
|
|
@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Optional, List
|
|||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||
from llmtuner.tuner.rm.trainer import PairwiseTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
@ -33,13 +34,12 @@ def run_rm(
|
|||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwisePeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = PairwiseTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
|
|
@ -4,10 +4,10 @@ import torch
|
|||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers import Seq2SeqTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
|
|
|
@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
|
|||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
@ -45,8 +45,7 @@ def run_sft(
|
|||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
|
|
Loading…
Reference in New Issue