improve rlhf

parent 9dcff3a5b5
commit c47ab6c072

@@ -31,31 +31,31 @@ class DatasetAttr:
     Dataset attributes.
     """

-    """ basic configs """
+    # basic configs
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
-    """ extra configs """
+    # extra configs
     subset: Optional[str] = None
     folder: Optional[str] = None
     num_samples: Optional[int] = None
-    """ common columns """
+    # common columns
     system: Optional[str] = None
     tools: Optional[str] = None
     images: Optional[str] = None
-    """ rlhf columns """
+    # rlhf columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
     kto_tag: Optional[str] = None
-    """ alpaca columns """
+    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
-    """ sharegpt columns """
+    # sharegpt columns
     messages: Optional[str] = "conversations"
-    """ sharegpt tags """
+    # sharegpt tags
     role_tag: Optional[str] = "from"
     content_tag: Optional[str] = "value"
     user_tag: Optional[str] = "human"
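
These attributes are populated from a dataset_info.json entry. A minimal sketch of how such an entry could map onto the fields above (the entry name and exact schema are assumptions for illustration, not shown in this diff):

import json

# Hypothetical dataset_info.json entry, written as a Python dict for illustration.
dataset_info = {
    "my_pairwise_data": {
        "file_name": "my_pairwise_data.json",  # implies load_from="file"
        "formatting": "sharegpt",              # DatasetAttr.formatting
        "ranking": True,                       # pairwise data, e.g. for reward modeling
        "columns": {                           # common / rlhf / sharegpt columns
            "messages": "conversations",
            "chosen": "chosen",
            "rejected": "rejected",
        },
        "tags": {                              # sharegpt tags
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
        },
    },
}
print(json.dumps(dataset_info, indent=2))
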
@@ -509,15 +509,19 @@ register_model_group(
         },
         "Gemma-2-9B": {
             DownloadSource.DEFAULT: "google/gemma-2-9b",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
         },
         "Gemma-2-27B": {
             DownloadSource.DEFAULT: "google/gemma-2-27b",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
         },
         "Gemma-2-9B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-9b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
         },
         "Gemma-2-27B-Chat": {
             DownloadSource.DEFAULT: "google/gemma-2-27b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
         },
     },
     template="gemma",
@@ -27,6 +27,7 @@ from accelerate.utils import DistributedDataParallelKwargs
 from tqdm import tqdm
 from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
 from transformers.optimization import get_scheduler
+from transformers.trainer import DEFAULT_CALLBACKS
 from transformers.trainer_callback import CallbackHandler
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
@@ -105,6 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
         ]
         ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
+        if ppo_config.log_with == "tensorboard":  # tensorboard raises error about accelerator_kwargs
+            ppo_config.log_with = None

         # Create optimizer and scheduler
         if training_args.max_steps > 0:
@@ -143,6 +146,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.control = TrainerControl()
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
@@ -339,11 +343,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             batch[k] = v[:, start_index:]

         with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
-            unwrapped_model = self.accelerator.unwrap_model(self.model)  # issue in trl v0.8.6
+            unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
             if self.model_args.upcast_layernorm:
                 layernorm_params = dump_layernorm(unwrapped_model)

-            generate_output: torch.Tensor = unwrapped_model.generate(
+            generate_output: "torch.Tensor" = unwrapped_model.generate(
                 generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
             )
             if self.model_args.upcast_layernorm:
@@ -354,12 +358,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         queries, responses = [], []
         for i in range(len(query)):
             query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
-            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+            response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()

-            if len(response_index) == 0:
-                response_length = 1  # allow empty response
+            if len(response_indexes) == 0:  # allow empty response
+                response_length = 1
+            elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id:  # include eos token
+                response_length = response_indexes[-1].item() + 2
             else:
-                response_length = response_index[-1].item() + 1
+                response_length = response_indexes[-1].item() + 1

             queries.append(query[i, query_start_index:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
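
The new elif branch covers tokenizers whose eos token doubles as the pad token: the last non-pad index then stops one short of the eos, so the slice is extended by one to keep it. A toy illustration of this logic (assuming pad_token_id == eos_token_id == 0):

import torch

response = torch.tensor([5, 6, 7, 0, 0, 0])        # content tokens, eos (== pad), padding
response_indexes = (response != 0).nonzero()       # last non-pad position is 2
response_length = response_indexes[-1].item() + 2  # +2 keeps the eos at position 3
print(response[:response_length])                  # tensor([5, 6, 7, 0])
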
@@ -382,7 +388,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
             return get_rewards_from_server(self.reward_model, messages)

-        batch = self.prepare_model_inputs(queries, responses)
+        batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         if self.finetuning_args.reward_model_type == "lora":
@@ -392,7 +398,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model

         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
+            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)

         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
@@ -400,13 +406,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.is_chatglm_model:  # assume same architecture
             values = torch.transpose(values, 0, 1)

-        rewards = []
-        for i in range(values.size(0)):
-            end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
-
-        return rewards
+        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
+        return rewards.to(torch.float32).detach().cpu()  # use fp32 type

     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
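
The per-sequence Python loop is replaced by a single gather: with right-padded inputs, attention_mask.sum(-1) - 1 is the index of each sequence's last real token, so one gather pulls out every final value-head score at once. A standalone sketch of this indexing (toy tensors, right padding assumed):

import torch

values = torch.tensor([[0.1, 0.2, 0.3],
                       [0.4, 0.5, 0.6]])                   # value-head output per token
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])                 # second sequence is right-padded
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # [[2], [1]]
rewards = values.gather(dim=-1, index=last_index)          # [[0.3], [0.5]]
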
@@ -440,7 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         attention_mask = input_kwargs["attention_mask"]

         with self.amp_context:  # support bf16
-            logits, _, values = model(**input_kwargs)
+            logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

         if self.is_chatglm_model:
             values = torch.transpose(values, 0, 1)
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict

 import numpy as np


-def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
-    preds, _ = eval_preds
-    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
+if TYPE_CHECKING:
+    from transformers import EvalPrediction
+
+
+def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
+    return {"accuracy": np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])}
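
The rewritten metric consumes a transformers EvalPrediction whose predictions hold chosen scores first and rejected scores second (shapes assumed from the code above). A toy check of the comparison:

import numpy as np

chosen = np.array([1.5, 0.2, 0.9])     # scores of chosen responses
rejected = np.array([0.5, 0.4, 0.1])   # scores of rejected responses
accuracy = np.mean(chosen > rejected)  # 2 of 3 pairs ranked correctly -> ~0.667
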
@@ -1,7 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the CarperAI's trlx library.
-# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,28 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# MIT License
-#
-# Copyright (c) 2022 CarperAI
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.

 import json
 import os
@@ -53,6 +31,7 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, ProcessorMixin
     from transformers.trainer import PredictionOutput
+    from trl import AutoModelForCausalLMWithValueHead

     from ...hparams import FinetuningArguments
@@ -108,46 +87,23 @@ class PairwiseTrainer(Trainer):
         See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
         """
         # Compute rewards
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)

-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
         if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)

         # Split the inputs and rewards into two parts, chosen and rejected
         batch_size = inputs["input_ids"].size(0) // 2
-        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
-        chosen_scores, rejected_scores = [], []
-
-        # Compute pairwise loss. Only backprop on the different tokens before padding
-        loss = 0
-        for i in range(batch_size):
-            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
-
-            if len(check_divergence) == 0:
-                end_index = chosen_length
-                div_index = end_index - 1
-            else:
-                end_index = max(chosen_length, rejected_length)
-                div_index = check_divergence[0]
-
-            assert div_index > 0
-            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
-            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs:  # use the score on the last token except pad token for inference
-                chosen_scores.append(chosen_rewards[i, chosen_length - 1])
-                rejected_scores.append(rejected_rewards[i, rejected_length - 1])
-            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
-
-        loss = loss / batch_size
+        chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
+        chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
+        chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
+        rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
+        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
+        loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
         if return_outputs:
-            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-            return loss, [loss, chosen_scores, rejected_scores]
-
-        return loss
+            return loss, (loss, chosen_scores, rejected_scores)
+        else:
+            return loss

     def save_predictions(self, predict_results: "PredictionOutput") -> None:
         r"""
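
The old per-sample loop scored only the span where the chosen and rejected sequences diverge; the new version computes a batched Bradley-Terry style loss on the last-token scores, mirroring the PPO-side gather above. A standalone sketch with toy tensors (the first half of the batch is assumed to hold the chosen examples):

import torch

values = torch.tensor([[0.0, 0.8, 1.2],    # per-token rewards, chosen example
                       [0.0, 0.3, 0.1]])   # per-token rewards, rejected example
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])
batch_size = values.size(0) // 2
chosen_masks, rejected_masks = torch.split(attention_mask, batch_size, dim=0)
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
loss = -torch.nn.functional.logsigmoid(chosen_scores - rejected_scores).mean()
# chosen score 1.2 vs rejected score 0.3 -> loss = -log(sigmoid(0.9)) ~= 0.341
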
@@ -1,7 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by the CarperAI's trlx library.
-# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,28 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
-# MIT License
-#
-# Copyright (c) 2022 CarperAI
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all
-# copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-# SOFTWARE.

 from typing import TYPE_CHECKING, List, Optional
@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Dict

 import numpy as np
 import torch
-from transformers import EvalPrediction
 from transformers.utils import is_jieba_available, is_nltk_available

 from ...extras.constants import IGNORE_INDEX

@@ -29,7 +28,7 @@ from ...extras.packages import is_rouge_available


 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
+    from transformers import EvalPrediction, PreTrainedTokenizer


 if is_jieba_available():
@@ -57,7 +57,7 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
     elif finetuning_args.stage == "kto":
         run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
     else:
-        raise ValueError("Unknown task.")
+        raise ValueError("Unknown task: {}.".format(finetuning_args.stage))


 def export_model(args: Optional[Dict[str, Any]] = None) -> None: