fix dpo trainer
commit 074745b170
parent 9a18a85639
@@ -130,6 +130,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if finetuning_args.stage == "ppo" and model_args.shift_attn:
         raise ValueError("PPO training is incompatible with S^2-Attn.")
 
+    if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
+        raise ValueError("Unsloth does not support lora reward model.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
 
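For context, a minimal sketch of how the newly added guard fires. The dataclasses below are hypothetical, simplified stand-ins, not LLaMA-Factory's actual argument classes:

from dataclasses import dataclass

# Hypothetical, simplified stand-ins for the parsed argument dataclasses.
@dataclass
class FinetuningArgs:
    stage: str = "ppo"
    reward_model_type: str = "lora"

@dataclass
class ModelArgs:
    use_unsloth: bool = True

def validate(finetuning_args: FinetuningArgs, model_args: ModelArgs) -> None:
    # Mirrors the new guard: a LoRA reward model cannot be combined with
    # Unsloth when the training stage is PPO.
    if (
        finetuning_args.stage == "ppo"
        and finetuning_args.reward_model_type == "lora"
        and model_args.use_unsloth
    ):
        raise ValueError("Unsloth does not support lora reward model.")

try:
    validate(FinetuningArgs(), ModelArgs())
except ValueError as err:
    print(err)  # Unsloth does not support lora reward model.

Fail-fast validation like this surfaces incompatible flag combinations at argument-parsing time rather than partway through a training run.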
@@ -69,7 +69,7 @@ class CustomDPOTrainer(DPOTrainer):
         Returns:
             A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
         """
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             chosen_logits,
             chosen_labels,
             average_log_prob=True
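For reference, a minimal sketch of what a batch log-probability helper like this computes, modeled on TRL's DPOTrainer.get_batch_logps (a simplification under that assumption, not the library's exact code):

import torch

def get_batch_logps(
    logits: torch.Tensor,       # (batch, seq_len, vocab)
    labels: torch.Tensor,       # (batch, seq_len)
    average_log_prob: bool = False,
    label_pad_token_id: int = -100,
) -> torch.Tensor:
    """Sum (or length-average) the log-probabilities of the label tokens."""
    # Shift so that tokens < n predict token n.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    # gather needs a valid index at padded positions; they are masked out below.
    labels[labels == label_pad_token_id] = 0
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return (per_token_logps * loss_mask).sum(-1)

With average_log_prob=True, as in the call above, the result is a length-normalized log-probability per sample, the negative of which is the per-sample cross-entropy described in the docstring.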
@@ -89,7 +89,7 @@ class CustomDPOTrainer(DPOTrainer):
             return_dict=True
         ).logits.to(torch.float32)
 
-        all_logps = self._get_batch_logps(
+        all_logps = self.get_batch_logps(
             all_logits,
             batch["labels"],
             average_log_prob=False
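Note the flag difference between the two call sites: average_log_prob=True (above) length-normalizes, giving the per-sample cross-entropy term, while average_log_prob=False here returns the summed sequence log-probability that the DPO preference loss is defined over.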