fix ppo train and dpo eval
commit 01260d9754 (parent 11c1e1e157)
@@ -75,6 +75,14 @@ class FinetuningArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_ref_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the reference model used for the DPO training."}
+    )
+    dpo_ref_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
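Because FinetuningArguments is parsed through HfArgumentParser, the two new fields surface as --dpo_ref_model and --dpo_ref_model_checkpoint command-line options. A minimal sketch of parsing them in isolation (paths are placeholders; the project normally parses all of its argument dataclasses together):

    from transformers import HfArgumentParser
    from llmtuner.hparams import FinetuningArguments

    parser = HfArgumentParser(FinetuningArguments)
    (finetuning_args,) = parser.parse_args_into_dataclasses(args=[
        "--dpo_ref_model", "path/to/reference-model",            # placeholder path
        "--dpo_ref_model_checkpoint", "path/to/ref-checkpoint",  # placeholder path
    ])
    assert finetuning_args.dpo_ref_model == "path/to/reference-model"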
@@ -91,7 +99,7 @@ class FinetuningArguments:
         if isinstance(self.additional_target, str):
             self.additional_target = [target.strip() for target in self.additional_target.split(",")]
 
-        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""

@@ -1,5 +1,5 @@
-from typing import Literal, Optional
-from dataclasses import dataclass, field
+from typing import Any, Dict, Literal, Optional
+from dataclasses import asdict, dataclass, field
 
 
 @dataclass
@@ -44,7 +44,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
     )
     flash_attn: Optional[bool] = field(
         default=False,
@@ -83,3 +83,6 @@ class ModelArguments:
 
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
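The new to_dict helper pairs with the reference-model handling in run_dpo further below: the current model arguments can be dumped to a plain dict, selectively overridden, and fed back into the constructor. A small sketch of that pattern (paths are placeholders; fields that __post_init__ rewrites may need explicit overrides rather than a plain round trip):

    from llmtuner.hparams import ModelArguments

    args = ModelArguments(model_name_or_path="path/to/base-model")      # placeholder path
    ref_args_dict = args.to_dict()                                      # plain dict via dataclasses.asdict
    ref_args_dict.update(model_name_or_path="path/to/reference-model")  # placeholder path
    ref_args = ModelArguments(**ref_args_dict)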
@@ -36,8 +36,8 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
 
-    if finetuning_args.finetuning_type == "none" and is_trainable:
-        raise ValueError("You cannot use finetuning_type=none while training.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.info("Checkpoint is not found at evaluation, load the original model.")
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
@@ -60,11 +60,11 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
-        latest_checkpoint = None
+        checkpoint_to_resume = None
 
         if model_args.checkpoint_dir is not None:
-            if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
-                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+            if is_trainable and finetuning_args.resume_lora_training:
+                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
 
@@ -75,10 +75,10 @@ def init_adapter(
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 
-            if latest_checkpoint is not None: # resume lora training or quantized inference
-                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
+            if checkpoint_to_resume is not None: # resume lora training
+                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
 
-        if is_trainable and latest_checkpoint is None: # create new lora weights while training
+        if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:

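The renamed checkpoint_to_resume keeps the existing merge-then-resume behaviour: with resume_lora_training, every entry of checkpoint_dir except the last is merged into the base weights, and the last one is re-attached as the trainable adapter. A hedged sketch of the equivalent peft calls (assumes `model` is an already-loaded base model; paths are placeholders):

    from peft import PeftModel

    checkpoint_dir = ["lora-ckpt-1", "lora-ckpt-2", "lora-ckpt-3"]  # placeholder paths
    checkpoints_to_merge, checkpoint_to_resume = checkpoint_dir[:-1], checkpoint_dir[-1]
    for checkpoint in checkpoints_to_merge:
        model = PeftModel.from_pretrained(model, checkpoint)
        model = model.merge_and_unload()  # fold the adapter into the base weights
    model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=True)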
@@ -15,6 +15,7 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
+from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
@@ -55,9 +56,6 @@ def load_model_and_tokenizer(
 
     Support both training and inference.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
 
     config_kwargs = {
         "trust_remote_code": True,
@@ -212,8 +210,11 @@ def load_model_and_tokenizer(
 
     if stage == "ppo": # load reward model
         logger.info("Load reward model from {}".format(model_args.reward_model))
-        if getattr(model, "is_peft_model", False):
+        if isinstance(model.pretrained_model, PeftModel):
             model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+                if "default" in name:
+                    param.data = param.data.to(torch.float32) # trainable params should in fp32
         assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference

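Loading the reward model as a second adapter named "reward" lets the PPO stage share one backbone between the policy (the "default" adapter) and the reward scorer. A hedged sketch of how the two adapters could be toggled with standard peft calls (`batch` is a placeholder for tokenized inputs; the actual switching happens inside the PPO trainer):

    # score with the reward adapter, then switch back to the policy adapter
    model.pretrained_model.set_adapter("reward")
    _, _, values = model(**batch)  # AutoModelForCausalLMWithValueHead returns (logits, loss, values)
    model.pretrained_model.set_adapter("default")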
@@ -1,20 +1,24 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 
-from copy import deepcopy
 from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments
 from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import DataArguments, FinetuningArguments
 
 
+logger = get_logger(__name__)
+
+
 def run_dpo(
@@ -34,9 +38,23 @@ def run_dpo(
     )
 
     # Create reference model
-    ref_model = None
-    if not isinstance(model, PeftModel):
-        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+    if finetuning_args.dpo_ref_model is not None:
+        ref_model_args_dict = model_args.to_dict()
+        ref_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.dpo_ref_model,
+            checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
+        ))
+        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
+        logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
+    elif training_args.do_train:
+        if isinstance(model, PeftModel):
+            ref_model = None
+        else:
+            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+        logger.info("Created reference model from the model itself.")
+    else:
+        ref_model = model
 
     # Update arguments
     training_args_dict = training_args.to_dict()
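When ref_model stays None for a LoRA (PeftModel) policy during training, the DPO trainer can still obtain reference log-probabilities from the same backbone by temporarily disabling the adapter, which is why only the full-model path needs a second copy. A hedged illustration using peft's adapter-disabling context manager (`batch` is a placeholder for tokenized inputs):

    # reference forward pass without a separate model:
    # run the policy with its LoRA adapter disabled
    with model.disable_adapter():
        ref_logits = model(**batch).logits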
@@ -68,6 +86,11 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
+        if id(model) == id(ref_model): # unable to compute rewards without a reference model
+            logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
+            remove_keys = [key for key in metrics.keys() if "rewards" in key]
+            for key in remove_keys:
+                metrics.pop(key)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
 
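In an evaluation-only run without dpo_ref_model, the policy is compared against itself, so the reward statistics carry no information; the block above strips them before logging. An illustrative sketch of that filter on a typical metrics dict (key names follow trl's DPOTrainer conventions and the values are placeholders, not measured results):

    metrics = {
        "eval_loss": 0.693,              # placeholder value
        "eval_rewards/chosen": 0.0,      # placeholder value
        "eval_rewards/rejected": 0.0,    # placeholder value
        "eval_rewards/margins": 0.0,     # placeholder value
    }
    remove_keys = [key for key in metrics.keys() if "rewards" in key]
    for key in remove_keys:
        metrics.pop(key)
    # only non-reward metrics such as "eval_loss" remain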