improve rlhf

This commit is contained in:
hiyouga 2024-07-02 22:23:08 +08:00
parent 9dcff3a5b5
commit c47ab6c072
8 changed files with 55 additions and 114 deletions

View File

@ -31,31 +31,31 @@ class DatasetAttr:
Dataset attributes.
"""
""" basic configs """
# basic configs
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """
# extra configs
subset: Optional[str] = None
folder: Optional[str] = None
num_samples: Optional[int] = None
""" common columns """
# common columns
system: Optional[str] = None
tools: Optional[str] = None
images: Optional[str] = None
""" rlhf columns """
# rlhf columns
chosen: Optional[str] = None
rejected: Optional[str] = None
kto_tag: Optional[str] = None
""" alpaca columns """
# alpaca columns
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
""" sharegpt columns """
# sharegpt columns
messages: Optional[str] = "conversations"
""" sharegpt tags """
# sharegpt tags
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"

View File

@ -509,15 +509,19 @@ register_model_group(
},
"Gemma-2-9B": {
DownloadSource.DEFAULT: "google/gemma-2-9b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
},
"Gemma-2-27B": {
DownloadSource.DEFAULT: "google/gemma-2-27b",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
},
"Gemma-2-9B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
},
"Gemma-2-27B-Chat": {
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
},
},
template="gemma",

View File

@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
from transformers.optimization import get_scheduler
from transformers.trainer import DEFAULT_CALLBACKS
from transformers.trainer_callback import CallbackHandler
from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@ -105,6 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with == "tensorboard": # tensorboard raises error about accelerator_kwargs
ppo_config.log_with = None
# Create optimizer and scheduler
if training_args.max_steps > 0:
@ -143,6 +146,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.control = TrainerControl()
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
)
@ -339,11 +343,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
batch[k] = v[:, start_index:]
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(unwrapped_model)
generate_output: torch.Tensor = unwrapped_model.generate(
generate_output: "torch.Tensor" = unwrapped_model.generate(
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
)
if self.model_args.upcast_layernorm:
@ -354,12 +358,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
queries, responses = [], []
for i in range(len(query)):
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
response_length = 1 # allow empty response
if len(response_indexes) == 0: # allow empty response
response_length = 1
elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id: # include eos token
response_length = response_indexes[-1].item() + 2
else:
response_length = response_index[-1].item() + 1
response_length = response_indexes[-1].item() + 1
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
@ -382,7 +388,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
return get_rewards_from_server(self.reward_model, messages)
batch = self.prepare_model_inputs(queries, responses)
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if self.finetuning_args.reward_model_type == "lora":
@ -392,7 +398,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
reward_model = self.reward_model
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
_, _, values = reward_model(**batch, return_dict=True, use_cache=False)
if self.finetuning_args.reward_model_type == "lora":
replace_model(unwrapped_model, target="default")
@ -400,13 +406,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.is_chatglm_model: # assume same architecture
values = torch.transpose(values, 0, 1)
rewards = []
for i in range(values.size(0)):
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
end_index = end_indexes[-1].item() if len(end_indexes) else 0
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
return rewards
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.to(torch.float32).detach().cpu() # use fp32 type
@PPODecorators.empty_device_cache()
def batched_forward_pass(
@ -440,7 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
attention_mask = input_kwargs["attention_mask"]
with self.amp_context: # support bf16
logits, _, values = model(**input_kwargs)
logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
if self.is_chatglm_model:
values = torch.transpose(values, 0, 1)

View File

@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Dict
import numpy as np
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds
return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
if TYPE_CHECKING:
from transformers import EvalPrediction
def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
return {"accuracy": np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])}

View File

@ -1,7 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,28 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2022 CarperAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
import os
@ -53,6 +31,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import PreTrainedModel, ProcessorMixin
from transformers.trainer import PredictionOutput
from trl import AutoModelForCausalLMWithValueHead
from ...hparams import FinetuningArguments
@ -108,46 +87,23 @@ class PairwiseTrainer(Trainer):
See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
"""
# Compute rewards
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
values = torch.transpose(values, 0, 1)
# Split the inputs and rewards into two parts, chosen and rejected
batch_size = inputs["input_ids"].size(0) // 2
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
chosen_scores, rejected_scores = [], []
# Compute pairwise loss. Only backprop on the different tokens before padding
loss = 0
for i in range(batch_size):
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
if len(check_divergence) == 0:
end_index = chosen_length
div_index = end_index - 1
else:
end_index = max(chosen_length, rejected_length)
div_index = check_divergence[0]
assert div_index > 0
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
if return_outputs: # use the score on the last token except pad token for inference
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
loss = loss / batch_size
chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
if return_outputs:
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
return loss, [loss, chosen_scores, rejected_scores]
return loss
return loss, (loss, chosen_scores, rejected_scores)
else:
return loss
def save_predictions(self, predict_results: "PredictionOutput") -> None:
r"""

View File

@ -1,7 +1,7 @@
# Copyright 2024 the LlamaFactory team.
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the CarperAI's trlx library.
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,28 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2022 CarperAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import TYPE_CHECKING, List, Optional

View File

@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Dict
import numpy as np
import torch
from transformers import EvalPrediction
from transformers.utils import is_jieba_available, is_nltk_available
from ...extras.constants import IGNORE_INDEX
@ -29,7 +28,7 @@ from ...extras.packages import is_rouge_available
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from transformers import EvalPrediction, PreTrainedTokenizer
if is_jieba_available():

View File

@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
elif finetuning_args.stage == "kto":
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
else:
raise ValueError("Unknown task.")
raise ValueError("Unknown task: {}.".format(finetuning_args.stage))
def export_model(args: Optional[Dict[str, Any]] = None) -> None: