From 01260d975477ebb8570933a1bd7f547b4dba607f Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 7 Nov 2023 22:48:51 +0800
Subject: [PATCH] fix ppo train and dpo eval

---
 src/llmtuner/hparams/finetuning_args.py | 10 +++++++-
 src/llmtuner/hparams/model_args.py      |  9 ++++---
 src/llmtuner/tuner/core/adapter.py      | 16 ++++++------
 src/llmtuner/tuner/core/loader.py       |  9 ++++---
 src/llmtuner/tuner/dpo/workflow.py      | 33 +++++++++++++++++++++----
 5 files changed, 56 insertions(+), 21 deletions(-)

diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index cf7608e0..82648ef6 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -75,6 +75,14 @@ class FinetuningArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_ref_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the reference model used for the DPO training."}
+    )
+    dpo_ref_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
@@ -91,7 +99,7 @@ class FinetuningArguments:
         if isinstance(self.additional_target, str):
             self.additional_target = [target.strip() for target in self.additional_target.split(",")]
 
-        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py
index e14f55de..62404d9e 100644
--- a/src/llmtuner/hparams/model_args.py
+++ b/src/llmtuner/hparams/model_args.py
@@ -1,5 +1,5 @@
-from typing import Literal, Optional
-from dataclasses import dataclass, field
+from typing import Any, Dict, Literal, Optional
+from dataclasses import asdict, dataclass, field
 
 
 @dataclass
@@ -44,7 +44,7 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
     )
     flash_attn: Optional[bool] = field(
         default=False,
@@ -83,3 +83,6 @@ class ModelArguments:
 
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 4c2984b1..8a771567 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -36,8 +36,8 @@ def init_adapter(
 
     Note that the trainable parameters must be cast to float32.
     """
-    if finetuning_args.finetuning_type == "none" and is_trainable:
-        raise ValueError("You cannot use finetuning_type=none while training.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.info("Checkpoint is not found at evaluation, load the original model.")
 
     if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
@@ -60,11 +60,11 @@ def init_adapter(
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
-        latest_checkpoint = None
+        checkpoint_to_resume = None
 
         if model_args.checkpoint_dir is not None:
-            if is_trainable and finetuning_args.resume_lora_training: # continually fine-tuning
-                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
+            if is_trainable and finetuning_args.resume_lora_training:
+                checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
 
@@ -75,10 +75,10 @@ def init_adapter(
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 
-            if latest_checkpoint is not None: # resume lora training or quantized inference
-                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
+            if checkpoint_to_resume is not None: # resume lora training
+                model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
 
-        if is_trainable and latest_checkpoint is None: # create new lora weights while training
+        if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
             if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 2931f087..80d2c658 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -15,6 +15,7 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
+from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
@@ -55,9 +56,6 @@ def load_model_and_tokenizer(
 
     Support both training and inference.
     """
-    if (not is_trainable) and model_args.checkpoint_dir is None:
-        logger.warning("Checkpoint is not found at evaluation, load the original model.")
-        finetuning_args = FinetuningArguments(finetuning_type="none")
 
     config_kwargs = {
         "trust_remote_code": True,
@@ -212,8 +210,11 @@ def load_model_and_tokenizer(
 
         if stage == "ppo": # load reward model
             logger.info("Load reward model from {}".format(model_args.reward_model))
-            if getattr(model, "is_peft_model", False):
+            if isinstance(model.pretrained_model, PeftModel):
                 model.pretrained_model.load_adapter(model_args.reward_model, "reward")
+            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+                if "default" in name:
+                    param.data = param.data.to(torch.float32) # trainable params should be in fp32
             assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
 
     # Prepare model for inference
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
index 0eef489f..240d34c5 100644
--- a/src/llmtuner/tuner/dpo/workflow.py
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -1,20 +1,24 @@
 # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
 
-from copy import deepcopy
 from peft import PeftModel
 from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
+from llmtuner.hparams import ModelArguments
 from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import DataArguments, FinetuningArguments
+
+
+logger = get_logger(__name__)
 
 
 def run_dpo(
@@ -34,9 +38,23 @@ def run_dpo(
     )
 
     # Create reference model
-    ref_model = None
-    if not isinstance(model, PeftModel):
-        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+    if finetuning_args.dpo_ref_model is not None:
+        ref_model_args_dict = model_args.to_dict()
+        ref_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.dpo_ref_model,
+            checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
+        ))
+        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
+        logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
+    elif training_args.do_train:
+        if isinstance(model, PeftModel):
+            ref_model = None
+        else:
+            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+            logger.info("Created reference model from the model itself.")
+    else:
+        ref_model = model
 
     # Update arguments
     training_args_dict = training_args.to_dict()
@@ -68,6 +86,11 @@ def run_dpo(
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
+        if id(model) == id(ref_model): # unable to compute rewards without a reference model
+            logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
+            remove_keys = [key for key in metrics.keys() if "rewards" in key]
+            for key in remove_keys:
+                metrics.pop(key)
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
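
Note: the reference-model selection this patch introduces in run_dpo() can be restated in isolation as the sketch below. This is an illustrative summary, not code from the patch: select_ref_model and load_ref are hypothetical stand-ins for the workflow logic and load_model_and_tokenizer, and it assumes the usual TRL behavior that a PEFT/LoRA policy with ref_model=None computes reference log-probs by temporarily disabling its adapters.

# Standalone sketch (not part of the patch) of the ref-model decision in run_dpo().
# "load_ref" is a hypothetical callable standing in for load_model_and_tokenizer().
from typing import Any, Callable, Optional


def select_ref_model(
    policy: Any,
    dpo_ref_model: Optional[str],
    dpo_ref_model_checkpoint: Optional[str],
    do_train: bool,
    policy_is_peft: bool,
    load_ref: Callable[[Optional[str], Optional[str]], Any],
) -> Optional[Any]:
    if dpo_ref_model is not None:
        # Explicit reference model: load it (with its own checkpoints) frozen.
        return load_ref(dpo_ref_model, dpo_ref_model_checkpoint)
    if do_train:
        # LoRA policy: leave ref_model as None so the trainer can recover the
        # reference by disabling the adapters; otherwise load a frozen copy.
        return None if policy_is_peft else load_ref(None, None)
    # Eval-only run without a reference model: reuse the policy itself, which
    # is why the reward metrics are stripped in the do_eval branch above.
    return policy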