support dpo-ftx

This commit is contained in:
hiyouga 2023-12-16 19:21:41 +08:00
parent 71389be37c
commit b87c74289d
6 changed files with 103 additions and 25 deletions

View File

@ -3,7 +3,7 @@ transformers>=4.36.1
datasets>=2.14.3
accelerate>=0.21.0
peft>=0.7.0
trl>=0.7.4
trl==0.7.4
gradio>=3.38.0,<4.0.0
scipy
sentencepiece

View File

@ -13,48 +13,37 @@ def get_package_version(name: str) -> str:
return "0.0.0"
_fastapi_available = is_package_available("fastapi")
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
_jieba_available = is_package_available("jieba")
_matplotlib_available = is_package_available("matplotlib")
_nltk_available = is_package_available("nltk")
_requests_available = is_package_available("requests")
_rouge_available = is_package_available("rouge_chinese")
_starlette_available = is_package_available("sse_starlette")
_uvicorn_available = is_package_available("uvicorn")
def is_fastapi_availble():
return _fastapi_available
return is_package_available("fastapi")
def is_flash_attn2_available():
return _flash_attn2_available
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
def is_jieba_available():
return _jieba_available
return is_package_available("jieba")
def is_matplotlib_available():
return _matplotlib_available
return is_package_available("matplotlib")
def is_nltk_available():
return _nltk_available
return is_package_available("nltk")
def is_requests_available():
return _requests_available
return is_package_available("requests")
def is_rouge_available():
return _rouge_available
return is_package_available("rouge_chinese")
def is_starlette_available():
return _starlette_available
return is_package_available("sse_starlette")
def is_uvicorn_available():
return _uvicorn_available
return is_package_available("uvicorn")

View File

@ -70,6 +70,14 @@ class RLHFArguments:
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."}
)
dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
default="sigmoid",
metadata={"help": "The type of DPO loss to use."}
)
dpo_ftx: Optional[float] = field(
default=0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
)
ppo_buffer_size: Optional[int] = field(
default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}

View File

@ -24,7 +24,7 @@ require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
require_version("trl==0.7.4", "To fix: pip install trl==0.7.4")
def load_model_and_tokenizer(

View File

@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer):
def __init__(
self,
beta: float,
loss_type: Literal["sigmoid", "hinge"],
ftx_gamma: float,
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
**kwargs
):
if disable_dropout:
@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer):
self.label_pad_token_id = IGNORE_INDEX
self.padding_value = 0
self.beta = beta
self.label_smoothing = 0
self.ftx_gamma = ftx_gamma
self.loss_type = loss_type
self._stored_metrics = defaultdict(lambda: defaultdict(list))
@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def sft_loss(
self,
chosen_logits: torch.FloatTensor,
chosen_labels: torch.LongTensor
) -> torch.Tensor:
r"""
Computes supervised cross-entropy loss of given labels under the given logits.
Returns:
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
"""
all_logps = self._get_batch_logps(
chosen_logits,
chosen_labels,
average_log_prob=True
)
return -all_logps
def concatenated_forward(
self,
model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer):
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
def get_batch_metrics(
self,
model: "PreTrainedModel",
batch: Dict[str, torch.Tensor],
train_eval: Optional[Literal["train", "eval"]] = "train"
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
r"""
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
"""
metrics = {}
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
reference_rejected_logps,
)
if self.ftx_gamma > 1e-6:
batch_size = batch["input_ids"].size(0) // 2
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
return losses.mean(), metrics

View File

@ -47,6 +47,8 @@ def run_dpo(
# Initialize our Trainer
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
loss_type=finetuning_args.dpo_loss,
ftx_gamma=finetuning_args.dpo_ftx,
model=model,
ref_model=ref_model,
args=training_args,