fix ppo train and dpo eval

This commit is contained in:
hiyouga 2023-11-07 22:48:51 +08:00
parent 11c1e1e157
commit 01260d9754
5 changed files with 56 additions and 21 deletions

View File

@ -75,6 +75,14 @@ class FinetuningArguments:
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
dpo_ref_model: Optional[str] = field(
default=None,
metadata={"help": "Path to the reference model used for the DPO training."}
)
dpo_ref_model_checkpoint: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
)
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@ -91,7 +99,7 @@ class FinetuningArguments:
if isinstance(self.additional_target, str):
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""

View File

@ -1,5 +1,5 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
@ -44,7 +44,7 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
)
flash_attn: Optional[bool] = field(
default=False,
@ -83,3 +83,6 @@ class ModelArguments:
if self.quantization_bit is not None:
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

View File

@ -36,8 +36,8 @@ def init_adapter(
Note that the trainable parameters must be cast to float32.
"""
if finetuning_args.finetuning_type == "none" and is_trainable:
raise ValueError("You cannot use finetuning_type=none while training.")
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.info("Checkpoint is not found at evaluation, load the original model.")
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
@ -60,11 +60,11 @@ def init_adapter(
if finetuning_args.finetuning_type == "lora":
logger.info("Fine-tuning method: LoRA")
latest_checkpoint = None
checkpoint_to_resume = None
if model_args.checkpoint_dir is not None:
if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
if is_trainable and finetuning_args.resume_lora_training:
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
else:
checkpoints_to_merge = model_args.checkpoint_dir
@ -75,10 +75,10 @@ def init_adapter(
if len(checkpoints_to_merge) > 0:
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
if latest_checkpoint is not None: # resume lora training or quantized inference
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
if checkpoint_to_resume is not None: # resume lora training
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
if is_trainable and latest_checkpoint is None: # create new lora weights while training
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:

View File

@ -15,6 +15,7 @@ from transformers import (
)
from transformers.models.llama import modeling_llama as LlamaModule
from transformers.utils.versions import require_version
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
try:
@ -55,9 +56,6 @@ def load_model_and_tokenizer(
Support both training and inference.
"""
if (not is_trainable) and model_args.checkpoint_dir is None:
logger.warning("Checkpoint is not found at evaluation, load the original model.")
finetuning_args = FinetuningArguments(finetuning_type="none")
config_kwargs = {
"trust_remote_code": True,
@ -212,8 +210,11 @@ def load_model_and_tokenizer(
if stage == "ppo": # load reward model
logger.info("Load reward model from {}".format(model_args.reward_model))
if getattr(model, "is_peft_model", False):
if isinstance(model.pretrained_model, PeftModel):
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
if "default" in name:
param.data = param.data.to(torch.float32) # trainable params should in fp32
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
# Prepare model for inference

View File

@ -1,20 +1,24 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
if TYPE_CHECKING:
from transformers import TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.hparams import DataArguments, FinetuningArguments
logger = get_logger(__name__)
def run_dpo(
@ -34,9 +38,23 @@ def run_dpo(
)
# Create reference model
ref_model = None
if not isinstance(model, PeftModel):
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
if finetuning_args.dpo_ref_model is not None:
ref_model_args_dict = model_args.to_dict()
ref_model_args_dict.update(dict(
model_name_or_path=finetuning_args.dpo_ref_model,
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
))
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
elif training_args.do_train:
if isinstance(model, PeftModel):
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
logger.info("Created reference model from the model itself.")
else:
ref_model = model
# Update arguments
training_args_dict = training_args.to_dict()
@ -68,6 +86,11 @@ def run_dpo(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
if id(model) == id(ref_model): # unable to compute rewards without a reference model
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
remove_keys = [key for key in metrics.keys() if "rewards" in key]
for key in remove_keys:
metrics.pop(key)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)