support dpo-ftx
This commit is contained in:
parent 71389be37c
commit b87c74289d
@@ -3,7 +3,7 @@ transformers>=4.36.1
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0
-trl>=0.7.4
+trl==0.7.4
 gradio>=3.38.0,<4.0.0
 scipy
 sentencepiece

@@ -13,48 +13,37 @@ def get_package_version(name: str) -> str:
     return "0.0.0"


-_fastapi_available = is_package_available("fastapi")
-_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
-_jieba_available = is_package_available("jieba")
-_matplotlib_available = is_package_available("matplotlib")
-_nltk_available = is_package_available("nltk")
-_requests_available = is_package_available("requests")
-_rouge_available = is_package_available("rouge_chinese")
-_starlette_available = is_package_available("sse_starlette")
-_uvicorn_available = is_package_available("uvicorn")
-
-
 def is_fastapi_availble():
-    return _fastapi_available
+    return is_package_available("fastapi")


 def is_flash_attn2_available():
-    return _flash_attn2_available
+    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")


 def is_jieba_available():
-    return _jieba_available
+    return is_package_available("jieba")


 def is_matplotlib_available():
-    return _matplotlib_available
+    return is_package_available("matplotlib")


 def is_nltk_available():
-    return _nltk_available
+    return is_package_available("nltk")


 def is_requests_available():
-    return _requests_available
+    return is_package_available("requests")


 def is_rouge_available():
-    return _rouge_available
+    return is_package_available("rouge_chinese")


 def is_starlette_available():
-    return _starlette_available
+    return is_package_available("sse_starlette")


 def is_uvicorn_available():
-    return _uvicorn_available
+    return is_package_available("uvicorn")

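Note: the helpers this hunk relies on are defined above the shown range. A minimal self-contained sketch of what they might look like follows; only the get_package_version signature and the "0.0.0" fallback appear in the diff, the importlib-based bodies are an assumption.

import importlib.metadata
import importlib.util


def is_package_available(name: str) -> bool:
    # After this change, availability is re-checked on every is_*_available()
    # call instead of being frozen in module-level flags at import time.
    return importlib.util.find_spec(name) is not None


def get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "0.0.0"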
@@ -70,6 +70,14 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
     )
+    dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(
+        default="sigmoid",
+        metadata={"help": "The type of DPO loss to use."}
+    )
+    dpo_ftx: Optional[float] = field(
+        default=0,
+        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
+    )
     ppo_buffer_size: Optional[int] = field(
         default=1,
         metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}

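A minimal sketch of how the new options behave; the trimmed dataclass below is a hypothetical stand-in, not the project's full RLHFArguments. Since dpo_ftx defaults to 0, the auxiliary SFT term stays disabled unless it is set above the trainer's 1e-6 threshold.

from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class RLHFArgumentsSketch:
    # Hypothetical stand-in; field names and defaults match the diff above.
    dpo_beta: Optional[float] = field(default=0.1)
    dpo_loss: Optional[Literal["sigmoid", "hinge"]] = field(default="sigmoid")
    dpo_ftx: Optional[float] = field(default=0)


args = RLHFArgumentsSketch(dpo_ftx=1.0)  # enable the auxiliary SFT loss term
print(args.dpo_loss, args.dpo_ftx)       # sigmoid 1.0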
@@ -24,7 +24,7 @@ require_version("transformers>=4.36.1", "To fix: pip install transformers>=4.36.1")
 require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
+require_version("trl==0.7.4", "To fix: pip install trl==0.7.4")


 def load_model_and_tokenizer(

@@ -16,10 +16,11 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
+        loss_type: Literal["sigmoid", "hinge"],
+        ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
-        loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
         **kwargs
     ):
         if disable_dropout:

@@ -34,6 +35,8 @@ class CustomDPOTrainer(DPOTrainer):
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
         self.beta = beta
+        self.label_smoothing = 0
+        self.ftx_gamma = ftx_gamma
         self.loss_type = loss_type
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

@@ -51,10 +54,28 @@ class CustomDPOTrainer(DPOTrainer):
         else:
             self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

+    def sft_loss(
+        self,
+        chosen_logits: torch.FloatTensor,
+        chosen_labels: torch.LongTensor
+    ) -> torch.Tensor:
+        r"""
+        Computes supervised cross-entropy loss of given labels under the given logits.
+
+        Returns:
+            A tensor of shape (batch_size,) containing the cross-entropy loss of each samples.
+        """
+        all_logps = self._get_batch_logps(
+            chosen_logits,
+            chosen_labels,
+            average_log_prob=True
+        )
+        return -all_logps
+
     def concatenated_forward(
         self,
-        model: Optional[torch.nn.Module] = None,
-        batch: Optional[Dict[str, torch.Tensor]] = None
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor]
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error

@@ -73,3 +94,61 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
         chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits
+
+    def get_batch_metrics(
+        self,
+        model: "PreTrainedModel",
+        batch: Dict[str, torch.Tensor],
+        train_eval: Optional[Literal["train", "eval"]] = "train"
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        r"""
+        Computes the DPO loss and other metrics for the given batch of inputs for train or test.
+        """
+        metrics = {}
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+        ) = self.concatenated_forward(model, batch)
+        with torch.no_grad():
+            if self.ref_model is None:
+                with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    (
+                        reference_chosen_logps,
+                        reference_rejected_logps,
+                        _,
+                        _,
+                    ) = self.concatenated_forward(self.model, batch)
+            else:
+                (
+                    reference_chosen_logps,
+                    reference_rejected_logps,
+                    _,
+                    _,
+                ) = self.concatenated_forward(self.ref_model, batch)
+
+        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+            reference_chosen_logps,
+            reference_rejected_logps,
+        )
+        if self.ftx_gamma > 1e-6:
+            batch_size = batch["input_ids"].size(0) // 2
+            chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
+            losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
+        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
+
+        return losses.mean(), metrics

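The core of the dpo-ftx feature is the `losses += self.ftx_gamma * self.sft_loss(...)` line above: the per-sample DPO losses are augmented with ftx_gamma times the average negative log-likelihood of the chosen responses. A standalone toy sketch of that combination follows; it uses random tensors and skips the label shifting that the real _get_batch_logps helper handles, so it is illustrative only.

import torch

IGNORE_INDEX = -100  # label id excluded from the loss, as in the trainer


def average_chosen_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Mean log-probability of the non-ignored labels, one value per sample
    # (roughly what sft_loss() negates via _get_batch_logps(average_log_prob=True)).
    logps = torch.log_softmax(logits, dim=-1)
    mask = labels.ne(IGNORE_INDEX)
    token_logps = torch.gather(logps, -1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask).sum(-1) / mask.sum(-1)


batch_size, seq_len, vocab = 2, 6, 11
chosen_logits = torch.randn(batch_size, seq_len, vocab)
chosen_labels = torch.randint(0, vocab, (batch_size, seq_len))
dpo_losses = torch.rand(batch_size)  # stand-in for the output of self.dpo_loss(...)
ftx_gamma = 1.0                      # would come from finetuning_args.dpo_ftx

sft_losses = -average_chosen_logps(chosen_logits, chosen_labels)
total_loss = (dpo_losses + ftx_gamma * sft_losses).mean()
print(total_loss)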
@@ -47,6 +47,8 @@ def run_dpo(
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
+        loss_type=finetuning_args.dpo_loss,
+        ftx_gamma=finetuning_args.dpo_ftx,
         model=model,
         ref_model=ref_model,
         args=training_args,