forked from p04798526/LLaMA-Factory-Mirror

support full-parameter PPO

parent 8350bcf85d
commit ce78303600
@@ -81,9 +81,9 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
-| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
-| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!NOTE]
 > Use `--quantization_bit 4/8` argument to enable QLoRA.
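The note above maps to the `quantization_bit` field of `ModelArguments` (whose validation is also touched later in this commit); together with LoRA fine-tuning it is what the README means by QLoRA. A minimal illustrative sketch, with placeholder values:

from llmtuner.hparams import ModelArguments

# Illustrative only: quantize the base model to 4 bits, i.e. the QLoRA setting.
model_args = ModelArguments(
    model_name_or_path="path/to/base_model",  # placeholder path
    quantization_bit=4,                       # must be 4 or 8 when set (asserted in ModelArguments)
)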
@@ -81,9 +81,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
-| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
-| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
-| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| PPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 
 > [!NOTE]
 > Please use the `--quantization_bit 4/8` argument to enable QLoRA training.
@@ -24,9 +24,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
 
     def _info(self):
         features = datasets.Features({
-            "instruction": datasets.Value("string"),
-            "output": datasets.Value("string"),
-            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
+            "conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
         })
         return datasets.DatasetInfo(
             description=_DESCRIPTION,
@@ -51,6 +49,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
         with open(filepath, "r", encoding="utf-8") as f:
             for key, row in enumerate(f):
                 data = json.loads(row)
+                conversations = []
                 prompt = data["instruction"].strip()
                 response = data["output"].strip()
 
@@ -58,7 +57,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                 human_idx = prompt.rfind("Human:")
                 query = prompt[human_idx+6:assist_idx].strip()
                 prompt = prompt[:human_idx].strip()
-                history = []
+                conversations.insert(0, {"from": "gpt", "value": response})
+                conversations.insert(0, {"from": "human", "value": query})
 
                 while prompt.rfind("Assistant:") != -1:
                     assist_idx = prompt.rfind("Assistant:")
@@ -66,13 +66,10 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
                     if human_idx != -1:
                         old_query = prompt[human_idx+6:assist_idx].strip()
                         old_resp = prompt[assist_idx+10:].strip()
-                        history.insert(0, (old_query, old_resp))
+                        conversations.insert(0, {"from": "gpt", "value": old_resp})
+                        conversations.insert(0, {"from": "human", "value": old_query})
                     else:
                         break
                     prompt = prompt[:human_idx].strip()
 
-                yield key, {
-                    "instruction": query,
-                    "output": response,
-                    "history": history
-                }
+                yield key, {"conversations": conversations}
@@ -88,11 +88,7 @@
   },
   "belle_multiturn": {
     "script_url": "belle_multiturn",
-    "columns": {
-      "prompt": "instruction",
-      "response": "output",
-      "history": "history"
-    }
+    "formatting": "sharegpt"
   },
   "ultra_chat": {
     "script_url": "ultra_chat",
@@ -107,6 +103,13 @@
   "alpaca_cot": {
     "hf_hub_url": "QingyiSi/Alpaca-CoT"
   },
+  "openorca": {
+    "hf_hub_url": "Open-Orca/OpenOrca",
+    "columns": {
+      "prompt": "question",
+      "response": "response"
+    }
+  },
   "mathinstruct": {
     "hf_hub_url": "TIGER-Lab/MathInstruct",
     "columns": {
@@ -66,6 +66,4 @@ class UltraChat(datasets.GeneratorBasedBuilder):
                     "from": "human" if i % 2 == 0 else "gpt",
                     "value": content[i]
                 } for i in range(len(content))]
-                yield key, {
-                    "conversations": conversations
-                }
+                yield key, {"conversations": conversations}
@@ -64,6 +64,16 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def get_current_device() -> str:
+    import accelerate
+    from accelerate import Accelerator
+    dummy_accelerator = Accelerator()
+    if accelerate.utils.is_xpu_available():
+        return "xpu:{}".format(dummy_accelerator.local_process_index)
+    else:
+        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
+
+
 def get_logits_processor() -> "LogitsProcessorList":
     r"""
     Gets logits processor that removes NaN and Inf logits.
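The new `get_current_device` helper resolves the accelerator-local device of the current process; the loader change later in this commit uses it to pin a quantized model to the right GPU instead of parsing `LOCAL_RANK` from the environment. A minimal usage sketch, assuming a multi-process launch via accelerate or torchrun:

from llmtuner.extras.misc import get_current_device

# Each rank builds its own device_map, so bitsandbytes places the quantized
# weights on that rank's local device (e.g. {"": 0} on rank 0, {"": 1} on rank 1,
# or "cpu" when CUDA/XPU is unavailable).
config_kwargs = {"device_map": {"": get_current_device()}}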
@@ -4,18 +4,10 @@ from dataclasses import asdict, dataclass, field
 
 
 @dataclass
-class FinetuningArguments:
+class FreezeArguments:
     r"""
-    Arguments pertaining to which techniques we are going to fine-tuning with.
+    Arguments pertaining to the freeze (partial-parameter) training.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
-        default="sft",
-        metadata={"help": "Which stage will be performed in training."}
-    )
-    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
-        default="lora",
-        metadata={"help": "Which fine-tuning method to use."}
-    )
     num_layer_trainable: Optional[int] = field(
         default=3,
         metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
@@ -30,6 +22,13 @@ class FinetuningArguments:
                   Phi-1.5 choices: [\"mlp\", \"mixer\"], \
                   Others choices: the same as LLaMA."}
     )
+
+
+@dataclass
+class LoraArguments:
+    r"""
+    Arguments pertaining to the LoRA training.
+    """
     lora_rank: Optional[int] = field(
         default=8,
         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
@@ -60,30 +59,76 @@ class FinetuningArguments:
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
-    ppo_score_norm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Use score normalization in PPO training."}
+
+
+@dataclass
+class RLHFArguments:
+    r"""
+    Arguments pertaining to the PPO and DPO training.
+    """
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
     )
     ppo_logger: Optional[str] = field(
         default=None,
         metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
     )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO training."}
+    )
     ppo_target: Optional[float] = field(
         default=6.0,
         metadata={"help": "Target KL value for adaptive KL control in PPO training."}
     )
-    dpo_beta: Optional[float] = field(
-        default=0.1,
-        metadata={"help": "The beta parameter for the DPO loss."}
+    ppo_whiten_rewards: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
     )
-    dpo_ref_model: Optional[str] = field(
+    ref_model: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the reference model used for the DPO training."}
+        metadata={"help": "Path to the reference model used for the PPO or DPO training."}
     )
-    dpo_ref_model_checkpoint: Optional[str] = field(
+    ref_model_checkpoint: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
     )
+    ref_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reference model."}
+    )
+    reward_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
+    )
+    reward_model_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the model checkpoints of the reward model."}
+    )
+    reward_model_quantization_bit: Optional[int] = field(
+        default=None,
+        metadata={"help": "The number of bits to quantize the reward model."}
+    )
+    reward_model_type: Optional[Literal["lora", "full"]] = field(
+        default="lora",
+        metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+    r"""
+    Arguments pertaining to which techniques we are going to fine-tuning with.
+    """
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
+        default="sft",
+        metadata={"help": "Which stage will be performed in training."}
+    )
+    finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
+        default="lora",
+        metadata={"help": "Which fine-tuning method to use."}
+    )
     upcast_layernorm: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to upcast the layernorm weights in fp32."}
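With the hyperparameters split into FreezeArguments, LoraArguments and RLHFArguments, the composed FinetuningArguments now carries the reward-model and reference-model settings that make full-parameter PPO possible. An illustrative instantiation; the field names come from the diff above, the paths are placeholders:

from llmtuner.hparams import FinetuningArguments

# Illustrative only: a full-parameter PPO configuration with a full-parameter reward model.
finetuning_args = FinetuningArguments(
    stage="ppo",
    finetuning_type="full",
    reward_model="path/to/reward_model",   # placeholder path
    reward_model_type="full",              # lora-type reward models are only allowed with lora training
    ppo_whiten_rewards=True,               # new flag, wired into PPOConfig later in this commit
)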
@@ -92,6 +137,14 @@ class FinetuningArguments:
         default=0,
         metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
     )
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
+    )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+    )
 
     def __post_init__(self):
         def split_arg(arg):
@@ -103,7 +156,13 @@ class FinetuningArguments:
         self.lora_alpha = self.lora_alpha or float(self.lora_rank * 2.0)
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
+        self.ref_model_checkpoint = split_arg(self.ref_model_checkpoint)
+        self.reward_model_checkpoint = split_arg(self.reward_model_checkpoint)
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
+        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
+        if self.reward_model_type == "lora" and self.finetuning_type != "lora":
+            raise ValueError("Lora reward model only supports lora training.")
 
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -54,22 +54,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
-        default=None,
-        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
-    )
-    plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
-        default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
-    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    export_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory to save the exported model."}
-    )
 
     def __post_init__(self):
         self.compute_dtype = None
@@ -81,8 +69,7 @@ class ModelArguments:
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
-        if self.quantization_bit is not None:
-            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+        assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -1,3 +1,3 @@
 from llmtuner.model.loader import load_model_and_tokenizer
 from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args
-from llmtuner.model.utils import dispatch_model, generate_model_card
+from llmtuner.model.utils import create_ref_model, create_reward_model, dispatch_model, generate_model_card
@@ -1,8 +1,5 @@
 import torch
 from typing import TYPE_CHECKING
 
-from transformers.utils import cached_file
-from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from peft import PeftModel, TaskType, LoraConfig, get_peft_model
 
 from llmtuner.extras.logging import get_logger
@@ -98,30 +95,3 @@ def init_adapter(
         logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
 
     return model
-
-
-def load_valuehead_params(
-    model: "PreTrainedModel",
-    model_args: "ModelArguments"
-) -> bool:
-    kwargs = {
-        "path_or_repo_id": model_args.reward_model,
-        "cache_dir": model_args.cache_dir,
-        "token": model_args.hf_hub_token,
-        "revision": model_args.model_revision
-    }
-    try:
-        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
-    except:
-        try:
-            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
-        except:
-            logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
-            return False
-
-    vhead_params = torch.load(vhead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
-    return True
@@ -15,7 +15,6 @@ from transformers import (
 )
 from transformers.models.llama import modeling_llama as LlamaModule
 from transformers.utils.versions import require_version
-from peft import PeftModel
 from trl import AutoModelForCausalLMWithValueHead
 
 try:
@@ -24,12 +23,12 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v
     from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from llmtuner.extras.logging import reset_logging, get_logger
-from llmtuner.extras.misc import count_parameters, infer_optim_dtype
+from llmtuner.extras.misc import count_parameters, get_current_device, infer_optim_dtype
 from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
-from llmtuner.model.adapter import init_adapter, load_valuehead_params
-from llmtuner.model.utils import prepare_model_for_training
+from llmtuner.model.adapter import init_adapter
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -74,6 +73,7 @@ def load_model_and_tokenizer(
     )
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        logger.info("Use `model_name_or_path` to specify the model trained with full/freeze method.")
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path
@@ -145,7 +145,7 @@ def load_model_and_tokenizer(
     else:
         logger.warning("Current model does not support shift short attention.")
 
-    # Quantization configurations (using bitsandbytes library).
+    # Quantization configurations (using bitsandbytes library)
     if model_args.quantization_bit is not None:
         if is_deepspeed_zero3_enabled():
             raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
@@ -165,10 +165,10 @@ def load_model_and_tokenizer(
             bnb_4bit_quant_type=model_args.quantization_type
         )
 
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
+        config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    # Load and prepare pre-trained models (without valuehead).
+    # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
@@ -186,7 +186,7 @@ def load_model_and_tokenizer(
         setattr(model, "lm_head", model.transformer.output_layer)
         setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
 
-    # Register auto class to save the custom code files.
+    # Register auto class to save the custom code files
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
     if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
@@ -200,25 +200,15 @@ def load_model_and_tokenizer(
     model = model.train() if is_trainable else model.eval()
 
     # Prepare model with valuehead for RLHF
-    if stage == "rm" or stage == "ppo":
+    if stage in ["rm", "ppo"]:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
-        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
+        if model_args.checkpoint_dir is not None: # load valuehead weights if exists
             logger.warning("Only the last checkpoint containing valuehead will be loaded.")
-            if load_valuehead_params(model, model_args):
-                model.v_head.load_state_dict({
-                    "summary.weight": getattr(model, "reward_head_weight"),
-                    "summary.bias": getattr(model, "reward_head_bias")
-                })
-
-        if stage == "ppo": # load reward model
-            logger.info("Load reward model from {}".format(model_args.reward_model))
-            if isinstance(model.pretrained_model, PeftModel):
-                model.pretrained_model.load_adapter(model_args.reward_model, "reward")
-            for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
-                if "default" in name:
-                    param.data = param.data.to(torch.float32) # trainable params should in fp32
-            assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
+            vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
+            if vhead_params is not None:
+                model.load_state_dict(vhead_params, strict=False)
+                logger.info("Loaded valuehead from checkpoint: {}".format(model_args.checkpoint_dir[-1]))
 
     # Prepare model for inference
     if not is_trainable:
@@ -1,17 +1,89 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, Tuple, Union
 
+from transformers.utils import cached_file
+from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
+
 from llmtuner.extras.constants import LAYERNORM_NAMES
 from llmtuner.extras.logging import get_logger
+from llmtuner.hparams import ModelArguments, FinetuningArguments
+from llmtuner.model import load_model_and_tokenizer
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from trl import AutoModelForCausalLMWithValueHead
+    from llmtuner.hparams import DataArguments
 
 
 logger = get_logger(__name__)
 
 
+def create_ref_model(
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
+    stage: Literal["ppo", "dpo"]
+) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
+    r"""
+    Creates reference model for PPO/DPO training. Evaluation mode is not supported.
+
+    The valuehead parameter is randomly initialized since it is useless for PPO training.
+    """
+    if finetuning_args.ref_model is not None:
+        ref_model_args_dict = model_args.to_dict()
+        ref_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.ref_model,
+            checkpoint_dir=finetuning_args.ref_model_checkpoint,
+            quantization_bit=finetuning_args.ref_model_quantization_bit
+        ))
+        ref_model_args = ModelArguments(**ref_model_args_dict)
+        ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
+        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
+        logger.info("Created reference model from {}".format(finetuning_args.ref_model))
+    else:
+        if finetuning_args.finetuning_type == "lora":
+            ref_model = None
+        else:
+            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
+            logger.info("Created reference model from the model itself.")
+
+    return ref_model
+
+
+def create_reward_model(
+    model: "AutoModelForCausalLMWithValueHead",
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments"
+) -> "AutoModelForCausalLMWithValueHead":
+    r"""
+    Creates reward model for PPO training.
+    """
+    if finetuning_args.reward_model_type == "lora":
+        model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
+        for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
+            if "default" in name:
+                param.data = param.data.to(torch.float32) # trainable params should in fp32
+        vhead_params = load_valuehead_params(model_args.checkpoint_dir[-1], model_args)
+        assert vhead_params is not None, "Reward model is not correctly loaded."
+        model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
+        model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
+        model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
+        model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
+        logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
+        return None
+    else:
+        reward_model_args_dict = model_args.to_dict()
+        reward_model_args_dict.update(dict(
+            model_name_or_path=finetuning_args.reward_model,
+            checkpoint_dir=finetuning_args.reward_model_checkpoint,
+            quantization_bit=finetuning_args.reward_model_quantization_bit
+        ))
+        reward_model_args = ModelArguments(**reward_model_args_dict)
+        reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
+        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
+        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        return reward_model
+
+
 def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     r"""
     Dispatches a pre-trained model to GPUs with balanced memory.
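The two helpers centralize model creation that the PPO and DPO workflows previously did inline. A brief sketch of how their return values are meant to be consumed, grounded in the code above and in the PPO trainer changes later in this commit:

# Sketch: create_ref_model may return None when the policy uses LoRA and no explicit
# ref_model is given; create_reward_model returns None for a lora-type reward model
# (the reward adapter and value-head buffers are attached to the policy itself),
# while a full-type reward model comes back as a standalone value-head model.
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)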
@@ -41,6 +113,9 @@ def find_all_linear_modules(
     model: "PreTrainedModel",
     quantization_bit: Optional[int] = None
 ) -> List[str]:
+    r"""
+    Finds all available modules to apply lora.
+    """
     if quantization_bit is not None:
         import bitsandbytes as bnb
         linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
@@ -76,6 +151,32 @@ def generate_model_card(
     }
 
 
+def load_valuehead_params(
+    path_or_repo_id: str,
+    model_args: "ModelArguments"
+) -> Dict[str, torch.Tensor]:
+    r"""
+    Loads value head parameters from Hugging Face Hub or local disk.
+
+    Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
+    """
+    kwargs = {
+        "path_or_repo_id": path_or_repo_id,
+        "cache_dir": model_args.cache_dir,
+        "token": model_args.hf_hub_token
+    }
+    try:
+        vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
+    except:
+        try:
+            vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
+        except:
+            logger.warning("Provided path ({}) does not contain valuehead weights.".format(path_or_repo_id))
+            return None
+
+    return torch.load(vhead_file, map_location="cpu")
+
+
 def prepare_model_for_training(
     model: "PreTrainedModel",
     finetuning_args: "FinetuningArguments",
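The relocated `load_valuehead_params` now returns the raw state dict (or None) instead of mutating the model, so callers decide how to apply it. A usage sketch mirroring the loader change earlier in this commit; the checkpoint path is a placeholder:

vhead_params = load_valuehead_params("path/to/checkpoint", model_args)  # placeholder path
if vhead_params is not None:
    # only fills v_head.summary.weight / v_head.summary.bias on the value-head model
    model.load_state_dict(vhead_params, strict=False)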
@@ -43,7 +43,11 @@ class CustomDPOTrainer(DPOTrainer):
 
         if ref_model is not None:
             if self.is_deepspeed_enabled:
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False)
+                    or getattr(ref_model, "is_loaded_in_4bit", False)
+                ): # quantized models are already set on the correct device
+                    self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
@@ -9,7 +9,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.hparams import ModelArguments
-from llmtuner.model import generate_model_card, load_model_and_tokenizer
+from llmtuner.model import create_ref_model, generate_model_card, load_model_and_tokenizer
 from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
 from llmtuner.train.dpo.trainer import CustomDPOTrainer
@@ -38,23 +38,10 @@ def run_dpo(
     )
 
     # Create reference model
-    if finetuning_args.dpo_ref_model is not None:
-        ref_model_args_dict = model_args.to_dict()
-        ref_model_args_dict.update(dict(
-            model_name_or_path=finetuning_args.dpo_ref_model,
-            checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
-        ))
-        ref_model_args = ModelArguments(**ref_model_args_dict)
-        ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
-        logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
-    elif training_args.do_train:
-        if isinstance(model, PeftModel):
-            ref_model = None
-        else:
-            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
-            logger.info("Created reference model from the model itself.")
-    else:
+    if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
         ref_model = model
+    else:
+        ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
 
     # Update arguments
     training_args_dict = training_args.to_dict()
@@ -80,14 +67,14 @@ def run_dpo(
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
-        if trainer.is_world_process_zero() and model_args.plot_loss:
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
     if training_args.do_eval:
         metrics = trainer.evaluate(metric_key_prefix="eval")
         if id(model) == id(ref_model): # unable to compute rewards without a reference model
-            logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
+            logger.warning("Specify `ref_model` for computing rewards at evaluation.")
             remove_keys = [key for key in metrics.keys() if "rewards" in key]
             for key in remove_keys:
                 metrics.pop(key)
@@ -37,24 +37,44 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["TrainerCallback"],
+        reward_model: "AutoModelForCausalLMWithValueHead",
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
 
         self.args = training_args
         self.model_args = model_args
         self.finetuning_args = finetuning_args
+
         self.generation_config = GenerationConfig(
             pad_token_id=self.tokenizer.pad_token_id,
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             **generating_args.to_dict()
         )
+
         self.state = TrainerState()
         self.control = TrainerControl()
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
 
         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
+        if reward_model is not None:
+            is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+                self.accelerator.state, "deepspeed_plugin"
+            )
+            if is_deepspeed_enabled:
+                if not (
+                    getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
+                    or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
+                ): # quantized models are already set on the correct device
+                    self.reward_model = self._prepare_deepspeed(self.reward_model)
+            else:
+                self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
+        else:
+            self.reward_model = None
+
     def ppo_train(self) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
@@ -213,11 +233,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         r"""
         Computes scores using given reward model.
         """
-        replace_model(unwrapped_model, target="reward")
+        if self.reward_model is None:
+            replace_model(unwrapped_model, target="reward")
+
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+            reward_model = self.reward_model if self.reward_model is not None else self.model
+            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
         if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
@@ -228,7 +251,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
-        replace_model(unwrapped_model, target="default")
+        if self.reward_model is None:
+            replace_model(unwrapped_model, target="default")
+
         return rewards
 
     @PPODecorators.empty_device_cache()
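Taken together, the two hunks above make reward scoring conditional: with a full-parameter reward model the trainer queries it directly, otherwise it temporarily swaps the LoRA reward adapter and value head onto the policy. A simplified consolidation of the resulting logic (end-index bookkeeping omitted):

# Sketch of get_rewards after this change, not a verbatim copy.
if self.reward_model is None:
    replace_model(unwrapped_model, target="reward")      # LoRA path: swap in the reward head

batch = self.prepare_model_inputs(queries, responses)
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):
    reward_model = self.reward_model if self.reward_model is not None else self.model
    _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)

if self.reward_model is None:
    replace_model(unwrapped_model, target="default")     # restore the policy head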
@@ -9,8 +9,9 @@ from transformers.optimization import get_scheduler
 
 from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import SavePeftModelCallback
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.ploting import plot_loss
-from llmtuner.model import load_model_and_tokenizer
+from llmtuner.model import create_ref_model, create_reward_model, load_model_and_tokenizer
 from llmtuner.train.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
@@ -18,6 +19,9 @@ if TYPE_CHECKING:
     from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
+logger = get_logger(__name__)
+
+
 def run_ppo(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -33,6 +37,11 @@ def run_ppo(
     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
+    # Create reference model and reward model
+    ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
+    reward_model = create_reward_model(model, model_args, finetuning_args)
+
+    # Create ppo config
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
         learning_rate=training_args.learning_rate,
@@ -47,9 +56,11 @@ def run_ppo(
         log_with=finetuning_args.ppo_logger,
         use_score_scaling=finetuning_args.ppo_score_norm,
         use_score_norm=finetuning_args.ppo_score_norm,
+        whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False}
     )
 
+    # Create optimizer and scheduler
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
@@ -73,9 +84,10 @@ def run_ppo(
         finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks + [SavePeftModelCallback()],
+        reward_model=reward_model,
         config=ppo_config,
         model=model,
-        ref_model=None,
+        ref_model=ref_model,
         tokenizer=tokenizer,
         dataset=dataset,
         data_collator=data_collator,
@@ -88,5 +100,5 @@ def run_ppo(
     ppo_trainer.ppo_train()
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be called after save_model to have a folder
-    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
        plot_loss(training_args.output_dir, keys=["loss", "reward"])
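Putting the pieces together, a full-parameter PPO run can now be driven entirely by the new FinetuningArguments fields. An illustrative launch sketch: it assumes the `run_exp` entry point exposed by the `llmtuner` package (as used by `train_bash.py`), and all paths plus the dataset key are placeholders:

from llmtuner import run_exp  # assumption: top-level training entry point

run_exp(dict(
    stage="ppo",
    do_train=True,
    model_name_or_path="path/to/sft_model",   # placeholder
    reward_model="path/to/reward_model",      # placeholder
    reward_model_type="full",                 # full-parameter reward model
    finetuning_type="full",                   # full-parameter PPO
    dataset="alpaca_gpt4_en",                 # placeholder key from dataset_info.json
    output_dir="path/to/ppo_output",          # placeholder
))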
@@ -42,7 +42,7 @@ def run_pt(
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
-        if trainer.is_world_process_zero() and model_args.plot_loss:
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
@@ -51,7 +51,7 @@ def run_rm(
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
-        if trainer.is_world_process_zero() and model_args.plot_loss:
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
@@ -69,7 +69,7 @@ def run_sft(
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
-        if trainer.is_world_process_zero() and model_args.plot_loss:
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
 
     # Evaluation
@@ -38,11 +38,11 @@ def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional
     model_args, _, finetuning_args, _ = get_infer_args(args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
     model.config.use_cache = True
-    model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
+    model.save_pretrained(finetuning_args.export_dir, max_shard_size=max_shard_size)
     try:
         tokenizer.padding_side = "left" # restore padding side
         tokenizer.init_kwargs["padding_side"] = "left"
-        tokenizer.save_pretrained(model_args.export_dir)
+        tokenizer.save_pretrained(finetuning_args.export_dir)
     except:
         logger.warning("Cannot save tokenizer, please copy the files manually.")