support dpo-ftx
This commit is contained in:
parent
71389be37c
commit
b87c74289d
|
@ -3,7 +3,7 @@ transformers>=4.36.1
|
|||
datasets>=2.14.3
|
||||
accelerate>=0.21.0
|
||||
peft>=0.7.0
|
||||
trl>=0.7.4
|
||||
trl==0.7.4
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
sentencepiece
|
||||
|
|
|
@ -13,48 +13,37 @@ def get_package_version(name: str) -> str:
|
|||
return "0.0.0"
|
||||
|
||||
|
||||
_fastapi_available = is_package_available("fastapi")
|
||||
_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
_jieba_available = is_package_available("jieba")
|
||||
_matplotlib_available = is_package_available("matplotlib")
|
||||
_nltk_available = is_package_available("nltk")
|
||||
_requests_available = is_package_available("requests")
|
||||
_rouge_available = is_package_available("rouge_chinese")
|
||||
_starlette_available = is_package_available("sse_starlette")
|
||||
_uvicorn_available = is_package_available("uvicorn")
|
||||
|
||||
|
||||
def is_fastapi_availble():
|
||||
return _fastapi_available
|
||||
return is_package_available("fastapi")
|
||||
|
||||
|
||||
def is_flash_attn2_available():
|
||||
return _flash_attn2_available
|
||||
return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _jieba_available
|
||||
return is_package_available("jieba")
|
||||
|
||||
|
||||
def is_matplotlib_available():
|
||||
return _matplotlib_available
|
||||
return is_package_available("matplotlib")
|
||||
|
||||
|
||||
def is_nltk_available():
|
||||
return _nltk_available
|
||||
return is_package_available("nltk")
|
||||
|
||||
|
||||
def is_requests_available():
|
||||
return _requests_available
|
||||
return is_package_available("requests")
|
||||
|
||||
|
||||
def is_rouge_available():
|
||||
return _rouge_available
|
||||
return is_package_available("rouge_chinese")
|
||||
|
||||
|
||||
def is_starlette_available():
|
||||
return _starlette_available
|
||||
return is_package_available("sse_starlette")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _uvicorn_available
|
||||
return is_package_available("uvicorn")
|
||||
|
|
|
@ -70,6 +70,14 @@ class RLHFArguments:
|
|||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
|
||||
default="sigmoid",
|
||||
metadata={"help": "The type of DPO loss to use."}
|
||||
)
|
||||
dpo_ftx: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
|
||||
)
|
||||
ppo_buffer_size: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
|
||||
|
|
|
@ -24,7 +24,7 @@ require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.
|
|||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
|
||||
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
||||
require_version("trl==0.7.4", "To fix: pip install trl==0.7.4")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
|
|
|
@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
loss_type: Literal["sigmoid", "hinge"],
|
||||
ftx_gamma: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
|
@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.label_smoothing = 0
|
||||
self.ftx_gamma = ftx_gamma
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
|
@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def sft_loss(
|
||||
self,
|
||||
chosen_logits: torch.FloatTensor,
|
||||
chosen_labels: torch.LongTensor
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Computes supervised cross-entropy loss of given labels under the given logits.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
|
||||
"""
|
||||
all_logps = self._get_batch_logps(
|
||||
chosen_logits,
|
||||
chosen_labels,
|
||||
average_log_prob=True
|
||||
)
|
||||
return -all_logps
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
|
@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer):
|
|||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
|
||||
def get_batch_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
batch: Dict[str, torch.Tensor],
|
||||
train_eval: Optional[Literal["train", "eval"]] = "train"
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
r"""
|
||||
Computes the DPO loss and other metrics for the given batch of inputs for train or test.
|
||||
"""
|
||||
metrics = {}
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.model, batch)
|
||||
else:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
)
|
||||
if self.ftx_gamma > 1e-6:
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
|
||||
losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
|
||||
|
||||
return losses.mean(), metrics
|
||||
|
|
|
@ -47,6 +47,8 @@ def run_dpo(
|
|||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
loss_type=finetuning_args.dpo_loss,
|
||||
ftx_gamma=finetuning_args.dpo_ftx,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
|
|
Loading…
Reference in New Issue