support full-parameter PPO
This commit is contained in:
parent
8350bcf85d
commit
ce78303600
|
@ -81,9 +81,9 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
|
|||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> Use `--quantization_bit 4/8` argument to enable QLoRA.
|
||||
|
|
|
@ -81,9 +81,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
|
|||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
| 奖励模型训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
|
||||
|
|
|
@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
|
@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
conversations = []
|
||||
prompt = data["instruction"].strip()
|
||||
response = data["output"].strip()
|
||||
|
||||
|
@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||
human_idx = prompt.rfind("Human:")
|
||||
query = prompt[human_idx+6:assist_idx].strip()
|
||||
prompt = prompt[:human_idx].strip()
|
||||
history = []
|
||||
conversations.insert(0, {"from": "gpt", "value": response})
|
||||
conversations.insert(0, {"from": "human", "value": query})
|
||||
|
||||
while prompt.rfind("Assistant:") != -1:
|
||||
assist_idx = prompt.rfind("Assistant:")
|
||||
|
@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
|||
if human_idx != -1:
|
||||
old_query = prompt[human_idx+6:assist_idx].strip()
|
||||
old_resp = prompt[assist_idx+10:].strip()
|
||||
history.insert(0, (old_query, old_resp))
|
||||
conversations.insert(0, {"from": "gpt", "value": old_resp})
|
||||
conversations.insert(0, {"from": "human", "value": old_query})
|
||||
else:
|
||||
break
|
||||
prompt = prompt[:human_idx].strip()
|
||||
|
||||
yield key, {
|
||||
"instruction": query,
|
||||
"output": response,
|
||||
"history": history
|
||||
}
|
||||
yield key, {"conversations": conversations}
|
||||
|
|
|
@ -88,11 +88,7 @@
|
|||
},
|
||||
"belle_multiturn": {
|
||||
"script_url": "belle_multiturn",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"ultra_chat": {
|
||||
"script_url": "ultra_chat",
|
||||
|
@ -107,6 +103,13 @@
|
|||
"alpaca_cot": {
|
||||
"hf_hub_url": "QingyiSi/Alpaca-CoT"
|
||||
},
|
||||
"openorca": {
|
||||
"hf_hub_url": "Open-Orca/OpenOrca",
|
||||
"columns": {
|
||||
"prompt": "question",
|
||||
"response": "response"
|
||||
}
|
||||
},
|
||||
"mathinstruct": {
|
||||
"hf_hub_url": "TIGER-Lab/MathInstruct",
|
||||
"columns": {
|
||||
|
|
|
@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
|
|||
"from": "human" if i % 2 == 0 else "gpt",
|
||||
"value": content[i]
|
||||
} for i in range(len(content))]
|
||||
yield key, {
|
||||
"conversations": conversations
|
||||
}
|
||||
yield key, {"conversations": conversations}
|
||||
|
|
|
@ -64,6 +64,16 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_current_device() -> str:
|
||||
import accelerate
|
||||
from accelerate import Accelerator
|
||||
dummy_accelerator = Accelerator()
|
||||
if accelerate.utils.is_xpu_available():
|
||||
return "xpu:{}".format(dummy_accelerator.local_process_index)
|
||||
else:
|
||||
return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
|
|
|
@ -4,18 +4,10 @@ from dataclasses import asdict, dataclass, field
|
|||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
class FreezeArguments:
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
Arguments pertaining to the freeze (partial-parameter) training.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
|
@ -30,6 +22,13 @@ class FinetuningArguments:
|
|||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
Others choices: the same as LLaMA."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraArguments:
|
||||
r"""
|
||||
Arguments pertaining to the LoRA training.
|
||||
"""
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
||||
|
@ -60,30 +59,76 @@ class FinetuningArguments:
|
|||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLHFArguments:
|
||||
r"""
|
||||
Arguments pertaining to the PPO and DPO training.
|
||||
"""
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."}
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0,
|
||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
ppo_whiten_rewards: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
|
||||
)
|
||||
dpo_ref_model: Optional[str] = field(
|
||||
ref_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the DPO training."}
|
||||
metadata={"help": "Path to the reference model used for the PPO or DPO training."}
|
||||
)
|
||||
dpo_ref_model_checkpoint: Optional[str] = field(
|
||||
ref_model_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
||||
)
|
||||
ref_model_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reference model."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
reward_model_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
|
||||
)
|
||||
reward_model_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the reward model."}
|
||||
)
|
||||
reward_model_type: Optional[Literal["lora", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
||||
|
@ -92,6 +137,14 @@ class FinetuningArguments:
|
|||
default=0,
|
||||
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
||||
)
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
def split_arg(arg):
|
||||
|
@ -103,7 +156,13 @@ class FinetuningArguments:
|
|||
self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
|
||||
self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
if self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
||||
raise ValueError("Lora reward model only supports lora training.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
|
|
@ -54,22 +54,10 @@ class ModelArguments:
|
|||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
)
|
||||
reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype = None
|
||||
|
@ -81,8 +69,7 @@ class ModelArguments:
|
|||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from llmtuner.model.loader import load_model_and_tokenizer
|
||||
from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
|
||||
from llmtuner.model.utils import dispatch_model, generate_model_card
|
||||
from llmtuner.model.utils import create_ref_model, create_reward_model, dispatch_model, generate_model_card
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from peft import PeftModel, TaskType, LoraConfig, get_peft_model
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
@ -98,30 +95,3 @@ def init_adapter(
|
|||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments"
|
||||
) -> bool:
|
||||
kwargs = {
|
||||
"path_or_repo_id": model_args.reward_model,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token,
|
||||
"revision": model_args.model_revision
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
|
||||
return False
|
||||
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
|
|
@ -15,7 +15,6 @@ from transformers import (
|
|||
)
|
||||
from transformers.models.llama import modeling_llama as LlamaModule
|
||||
from transformers.utils.versions import require_version
|
||||
from peft import PeftModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
|
@ -24,12 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
|
|||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
|
||||
from llmtuner.extras.packages import is_flash_attn2_available
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.model.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.model.utils import prepare_model_for_training
|
||||
from llmtuner.model.adapter import init_adapter
|
||||
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
@ -74,6 +73,7 @@ def load_model_and_tokenizer(
|
|||
)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
@ -145,7 +145,7 @@ def load_model_and_tokenizer(
|
|||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
# Quantization configurations (using bitsandbytes library)
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
@ -165,10 +165,10 @@ def load_model_and_tokenizer(
|
|||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
config_kwargs["device_map"] = {"": get_current_device()}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
# Load and prepare pre-trained models (without valuehead).
|
||||
# Load pre-trained models (without valuehead)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
|
@ -186,7 +186,7 @@ def load_model_and_tokenizer(
|
|||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
# Register auto class to save the custom code files
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
|
||||
|
@ -200,25 +200,15 @@ def load_model_and_tokenizer(
|
|||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
if stage in ["rm", "ppo"]:
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
if model_args.checkpoint_dir is not None: # load valuehead weights if exists
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
if load_valuehead_params(model, model_args):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
if isinstance(model.pretrained_model, PeftModel):
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
|
||||
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
|
||||
if vhead_params is not None:
|
||||
model.load_state_dict(vhead_params, strict=False)
|
||||
logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
|
|
|
@ -1,17 +1,89 @@
|
|||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def create_ref_model(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
stage: Literal["ppo", "dpo"]
|
||||
) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
|
||||
r"""
|
||||
Creates reference model for PPO/DPO training. Evaluation mode is not supported.
|
||||
|
||||
The valuehead parameter is randomly initialized since it is useless for PPO training.
|
||||
"""
|
||||
if finetuning_args.ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.ref_model,
|
||||
checkpoint_dir=finetuning_args.ref_model_checkpoint,
|
||||
quantization_bit=finetuning_args.ref_model_quantization_bit
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
|
||||
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
|
||||
else:
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
|
||||
logger.info("Created reference model from the model itself.")
|
||||
|
||||
return ref_model
|
||||
|
||||
|
||||
def create_reward_model(
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> "AutoModelForCausalLMWithValueHead":
|
||||
r"""
|
||||
Creates reward model for PPO training.
|
||||
"""
|
||||
if finetuning_args.reward_model_type == "lora":
|
||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
|
||||
assert vhead_params is not None, "Reward model is not correctly loaded."
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
return None
|
||||
else:
|
||||
reward_model_args_dict = model_args.to_dict()
|
||||
reward_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.reward_model,
|
||||
checkpoint_dir=finetuning_args.reward_model_checkpoint,
|
||||
quantization_bit=finetuning_args.reward_model_quantization_bit
|
||||
))
|
||||
reward_model_args = ModelArguments(**reward_model_args_dict)
|
||||
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
|
||||
reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
|
||||
logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
return reward_model
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
|
@ -41,6 +113,9 @@ def find_all_linear_modules(
|
|||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora.
|
||||
"""
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
|
@ -76,6 +151,32 @@ def generate_model_card(
|
|||
}
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
path_or_repo_id: str,
|
||||
model_args: "ModelArguments"
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Loads value head parameters from Hugging Face Hub or local disk.
|
||||
|
||||
Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
|
||||
"""
|
||||
kwargs = {
|
||||
"path_or_repo_id": path_or_repo_id,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
|
||||
return None
|
||||
|
||||
return torch.load(vhead_file, map_location="cpu")
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
|
|
|
@ -43,7 +43,11 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False)
|
||||
or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
|
|||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from llmtuner.model import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.model import create_ref_model, generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.train.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
|
@ -38,23 +38,10 @@ def run_dpo(
|
|||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.dpo_ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.dpo_ref_model,
|
||||
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
|
||||
elif training_args.do_train:
|
||||
if isinstance(model, PeftModel):
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from the model itself.")
|
||||
else:
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
|
@ -80,14 +67,14 @@ def run_dpo(
|
|||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
|
||||
logger.warning("Specify `ref_model` for computing rewards at evaluation.")
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
|
|
|
@ -37,24 +37,44 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
reward_model: "AutoModelForCausalLMWithValueHead",
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
|
||||
self.args = training_args
|
||||
self.model_args = model_args
|
||||
self.finetuning_args = finetuning_args
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
**generating_args.to_dict()
|
||||
)
|
||||
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
||||
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
if reward_model is not None:
|
||||
is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
|
||||
self.accelerator.state, "deepspeed_plugin"
|
||||
)
|
||||
if is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||
or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.reward_model = self._prepare_deepspeed(self.reward_model)
|
||||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
else:
|
||||
self.reward_model = None
|
||||
|
||||
def ppo_train(self) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
|
@ -213,11 +233,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
r"""
|
||||
Computes scores using given reward model.
|
||||
"""
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
if self.reward_model is None:
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
|
||||
reward_model = self.reward_model if self.reward_model is not None else self.model
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
@ -228,7 +251,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
if self.reward_model is None:
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
|
|
|
@ -9,8 +9,9 @@ from transformers.optimization import get_scheduler
|
|||
|
||||
from llmtuner.data import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
from llmtuner.model import create_ref_model, create_reward_model, load_model_and_tokenizer
|
||||
from llmtuner.train.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -18,6 +19,9 @@ if TYPE_CHECKING:
|
|||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_ppo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
|
@ -33,6 +37,11 @@ def run_ppo(
|
|||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
# Create reference model and reward model
|
||||
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
|
||||
reward_model = create_reward_model(model, model_args, finetuning_args)
|
||||
|
||||
# Create ppo config
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
learning_rate=training_args.learning_rate,
|
||||
|
@ -47,9 +56,11 @@ def run_ppo(
|
|||
log_with=finetuning_args.ppo_logger,
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
whiten_rewards=finetuning_args.ppo_whiten_rewards,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
|
@ -73,9 +84,10 @@ def run_ppo(
|
|||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
reward_model=reward_model,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=data_collator,
|
||||
|
@ -88,5 +100,5 @@ def run_ppo(
|
|||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
|
|
@ -42,7 +42,7 @@ def run_pt(
|
|||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
|
|
|
@ -51,7 +51,7 @@ def run_rm(
|
|||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
|
|
|
@ -69,7 +69,7 @@ def run_sft(
|
|||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
|
|
|
@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
|
|||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
||||
model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
tokenizer.save_pretrained(model_args.export_dir)
|
||||
tokenizer.save_pretrained(finetuning_args.export_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue