add kto
This commit is contained in:
parent
84415492bf
commit
db1d5a4f51
|
@ -32,6 +32,15 @@
|
|||
"history": "history"
|
||||
}
|
||||
},
|
||||
"kto-mix-test": {
|
||||
"file_name": "kto-mix-test.json",
|
||||
"file_sha1": "91b59f657007dc4b17529fc643v9b9cd6d640fha",
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"response": "output",
|
||||
"tag": "tag"
|
||||
}
|
||||
},
|
||||
"glaive_toolcall": {
|
||||
"file_name": "glaive_toolcall_10k.json",
|
||||
"formatting": "sharegpt",
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -1,4 +1,4 @@
|
|||
from .collator import PairwiseDataCollatorWithPadding
|
||||
from .collator import PairwiseDataCollatorWithPadding,KTODataCollatorWithPadding
|
||||
from .loader import get_dataset
|
||||
from .template import Template, get_template_and_fix_tokenizer, templates
|
||||
from .utils import Role, split_dataset
|
||||
|
@ -6,6 +6,7 @@ from .utils import Role, split_dataset
|
|||
|
||||
__all__ = [
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"KTODataCollatorWithPadding",
|
||||
"get_dataset",
|
||||
"Template",
|
||||
"get_template_and_fix_tokenizer",
|
||||
|
|
|
@ -29,7 +29,7 @@ def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "
|
|||
def convert_alpaca(
|
||||
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
||||
) -> Dict[str, List[Any]]:
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
|
||||
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": [], "tag": []}
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
for i in range(len(examples[dataset_attr.prompt])):
|
||||
prompt = []
|
||||
|
@ -61,6 +61,7 @@ def convert_alpaca(
|
|||
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
||||
outputs["tools"].append("")
|
||||
outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
|
||||
outputs["tag"].append(examples[dataset_attr.tag][i] if dataset_attr.tag else True)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -137,6 +138,7 @@ def align_dataset(
|
|||
"system": {"dtype": "string", "_type": "Value"},
|
||||
"tools": {"dtype": "string", "_type": "Value"},
|
||||
"images": [{"_type": "Image"}],
|
||||
"tag": {"dtype": "bool", "_type": "Value"},
|
||||
}
|
||||
)
|
||||
kwargs = {}
|
||||
|
|
|
@ -49,3 +49,36 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
|||
batch = super().__call__(concatenated_features)
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
|
||||
@dataclass
|
||||
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
def __call__(self, features, return_tensors=None):
|
||||
concatenated_features = []
|
||||
kl_concatenated_features = []
|
||||
tags = []
|
||||
for feature in features:
|
||||
concatenated_features.append(
|
||||
{
|
||||
"input_ids": feature["input_ids"],
|
||||
"attention_mask": feature["attention_mask"],
|
||||
"labels": feature["labels"],
|
||||
}
|
||||
)
|
||||
kl_concatenated_features.append(
|
||||
{
|
||||
"input_ids": feature["kl_input_ids"],
|
||||
"attention_mask": feature["kl_attention_mask"],
|
||||
"labels": feature["kl_labels"],
|
||||
}
|
||||
)
|
||||
tags.append(feature["tag"])
|
||||
batch = super().__call__(concatenated_features)
|
||||
kl_batch = super().__call__(kl_concatenated_features)
|
||||
batch["KL_completion_input_ids"] = kl_batch["input_ids"]
|
||||
batch["KL_completion_attention_mask"] = kl_batch["attention_mask"]
|
||||
batch["kl_labels"] = kl_batch["labels"]
|
||||
batch["tag"] = torch.tensor(tags)
|
||||
return batch
|
|
@ -116,7 +116,7 @@ def get_dataset(
|
|||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
|
|
|
@ -28,6 +28,7 @@ class DatasetAttr:
|
|||
""" columns """
|
||||
system: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
tag: Optional[bool] = None
|
||||
""" columns for the alpaca format """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
|
@ -106,7 +107,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system", "images"]
|
||||
column_names = ["system", "images", "tag"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
|
|
|
@ -70,7 +70,7 @@ def preprocess_supervised_dataset(
|
|||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [], "tag": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
@ -111,11 +111,102 @@ def preprocess_supervised_dataset(
|
|||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["tag"].append(examples["tag"])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_kto_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"kl_input_ids": [], "kl_attention_mask": [], "kl_labels": [], "tag": []}
|
||||
"""Creates mismatched pairs of prompts and completions for the KL dataset by reversing the order of completions."""
|
||||
examples['kl_response'] = examples['response'][::-1]
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
preprocess_visual_inputs = partial(_preprocess_visual_inputs, processor=processor)
|
||||
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
continue
|
||||
|
||||
if processor is not None:
|
||||
examples["prompt"][i][0]["content"] = "<image>" + examples["prompt"][i][0]["content"]
|
||||
|
||||
messages = examples["prompt"][i] + examples["response"][i]
|
||||
kl_messages = examples["prompt"][i] + examples["kl_response"][i]
|
||||
input_ids, labels = [], []
|
||||
kl_input_ids, kl_labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(
|
||||
template.encode_multiturn(
|
||||
tokenizer,
|
||||
kl_messages,
|
||||
examples["system"][i],
|
||||
examples["tools"][i],
|
||||
data_args.cutoff_len,
|
||||
data_args.reserved_label_len,
|
||||
)
|
||||
):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
kl_input_ids += source_ids + target_ids
|
||||
kl_labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
kl_input_ids += [tokenizer.eos_token_id]
|
||||
kl_labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
model_inputs["kl_input_ids"].append(kl_input_ids)
|
||||
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
|
||||
model_inputs["kl_labels"].append(kl_labels)
|
||||
model_inputs["tag"].append(examples["tag"][i])
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"].append(preprocess_visual_inputs(examples["images"][i]))
|
||||
desirable = sum([1 for tag in model_inputs["tag"] if tag is True])
|
||||
undesirable = sum([1 for tag in model_inputs["tag"] if tag is False])
|
||||
logger.info("desirable data in KTO dataset: {},undesirable data in KTO dataset: {}".format(desirable, undesirable))
|
||||
if desirable == 0 or undesirable == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
return model_inputs
|
||||
|
||||
def preprocess_packed_supervised_dataset(
|
||||
examples: Dict[str, List[Any]],
|
||||
|
@ -289,7 +380,7 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
|||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
|
@ -328,6 +419,15 @@ def get_preprocess_and_print_func(
|
|||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "kto":
|
||||
preprocess_func = partial(
|
||||
preprocess_kto_dataset,
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||
else:
|
||||
preprocess_func = partial(
|
||||
preprocess_unsupervised_dataset,
|
||||
|
|
|
@ -45,6 +45,7 @@ TRAINING_STAGES = {
|
|||
"Reward Modeling": "rm",
|
||||
"PPO": "ppo",
|
||||
"DPO": "dpo",
|
||||
"KTO": "kto",
|
||||
"ORPO": "orpo",
|
||||
"Pre-Training": "pt",
|
||||
}
|
||||
|
|
|
@ -133,6 +133,22 @@ class RLHFArguments:
|
|||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
||||
)
|
||||
kto_beta: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the KTO loss."},
|
||||
)
|
||||
kto_ftx: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in KTO training."},
|
||||
)
|
||||
kto_desirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The desirable weight for the KTO loss."},
|
||||
)
|
||||
kto_undesirable_weight: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "The undesirable weight for the KTO loss."},
|
||||
)
|
||||
orpo_beta: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta (lambda) parameter in ORPO loss representing the weight of the SFT loss."},
|
||||
|
@ -291,7 +307,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||
default=False,
|
||||
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
||||
)
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field(
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "kto"] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."},
|
||||
)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .workflow import run_kto
|
||||
|
||||
|
||||
__all__ = ["run_kto"]
|
|
@ -0,0 +1,206 @@
|
|||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
from trl import KTOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ..utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...hparams import FinetuningArguments
|
||||
|
||||
|
||||
class CustomKTOTrainer(KTOTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
|
||||
finetuning_args: "FinetuningArguments",
|
||||
disable_dropout: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
self.precompute_ref_log_probs = False
|
||||
self._precomputed_train_ref_log_probs = False
|
||||
self._precomputed_eval_ref_log_probs = False
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
self.ref_model = ref_model
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# KTO parameter
|
||||
self.beta = finetuning_args.kto_beta
|
||||
self.ftx_gamma = finetuning_args.kto_ftx
|
||||
self.desirable_weight = finetuning_args.kto_desirable_weight
|
||||
self.undesirable_weight = finetuning_args.kto_undesirable_weight
|
||||
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||
): # quantized models are already set on the correct device
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def sft_loss(self, chosen_logits: "torch.FloatTensor", chosen_labels: "torch.LongTensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self.get_batch_logps(chosen_logits, chosen_labels, average_log_prob=True)
|
||||
return -all_logps.nanmean()
|
||||
|
||||
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
with torch.no_grad():
|
||||
KL_logits = model(
|
||||
batch["KL_completion_input_ids"],
|
||||
attention_mask=batch["KL_completion_attention_mask"],
|
||||
).logits
|
||||
|
||||
completion_logits = model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
).logits
|
||||
|
||||
completion_logps = self.get_batch_logps(
|
||||
completion_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
KL_logps = self.get_batch_logps(
|
||||
KL_logits,
|
||||
batch["kl_labels"],
|
||||
average_log_prob=False,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
if completion_logps.shape[0] != len(batch["tag"]):
|
||||
raise ValueError(
|
||||
"There is a mismatch between the number of examples in this batch and the number of "
|
||||
"examples for which an output sequence was predicted."
|
||||
)
|
||||
chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["tag"][i]]
|
||||
rejected_idx = [i for i in range(completion_logps.shape[0]) if not batch["tag"][i]]
|
||||
|
||||
chosen_logps = completion_logps[chosen_idx, ...]
|
||||
rejected_logps = completion_logps[rejected_idx, ...]
|
||||
|
||||
chosen_logits = completion_logits[chosen_idx, ...]
|
||||
rejected_logits = completion_logits[rejected_idx, ...]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps)
|
||||
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
):
|
||||
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_KL_logps,
|
||||
) = self.forward(model, batch)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
ref_model = self.model
|
||||
ref_context = self.accelerator.unwrap_model(self.model).disable_adapter()
|
||||
else:
|
||||
ref_model = self.ref_model
|
||||
ref_context = nullcontext()
|
||||
with ref_context:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
reference_KL_logps,
|
||||
) = self.forward(ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_KL_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
reference_KL_logps,
|
||||
)
|
||||
losses = losses.nanmean()
|
||||
if self.ftx_gamma > 1e-6 and len(batch["labels"][batch['tag']])>0:
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, batch["labels"][batch['tag']])
|
||||
|
||||
|
||||
num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
|
||||
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
|
||||
|
||||
all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
|
||||
all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
|
||||
|
||||
if all_num_chosen > 0:
|
||||
metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
|
||||
metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
|
||||
metrics["count/chosen"] = all_num_chosen
|
||||
|
||||
if all_num_rejected > 0:
|
||||
metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
|
||||
metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
|
||||
metrics["count/rejected"] = all_num_rejected
|
||||
|
||||
metrics["kl"] = kl.item()
|
||||
|
||||
return losses, metrics
|
|
@ -0,0 +1,78 @@
|
|||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..utils import create_modelcard_and_push, create_ref_model
|
||||
from .trainer import CustomKTOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_kto(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = KTODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=8,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
|
||||
ref_model = model
|
||||
else:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
# Update arguments
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomKTOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
|
@ -14,7 +14,7 @@ from .ppo import run_ppo
|
|||
from .pt import run_pt
|
||||
from .rm import run_rm
|
||||
from .sft import run_sft
|
||||
|
||||
from .kto import run_kto
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
@ -39,6 +39,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb
|
|||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "orpo":
|
||||
run_orpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "kto":
|
||||
run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue