support ppo score norm (trl 0.5.1.dev required)

This commit is contained in:
hiyouga 2023-08-18 12:02:42 +08:00
parent 9020524418
commit 53e33418d0
2 changed files with 10 additions and 0 deletions

View File

@ -61,6 +61,10 @@ class FinetuningArguments:
default=True, default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
) )
ppo_score_norm: Optional[bool] = field(
default=False,
metadata={"help": "Use score normalization in PPO Training."}
)
dpo_beta: Optional[float] = field( dpo_beta: Optional[float] = field(
default=0.1, default=0.1,
metadata={"help": "The beta parameter for the DPO loss."} metadata={"help": "The beta parameter for the DPO loss."}

View File

@ -6,6 +6,7 @@ from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
@ -42,6 +43,11 @@ def run_ppo(
optimize_cuda_cache=True optimize_cuda_cache=True
) )
if finetuning_args.ppo_score_norm:
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = ( total_train_batch_size = (
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size