From 53e33418d02ee0f34c783e30ae510b811308c598 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 18 Aug 2023 12:02:42 +0800
Subject: [PATCH] support ppo score norm (trl 0.5.1.dev required)

---
 src/llmtuner/hparams/finetuning_args.py | 4 ++++
 src/llmtuner/tuner/ppo/workflow.py      | 6 ++++++
 2 files changed, 10 insertions(+)

diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 98598dd5..5af4549e 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -61,6 +61,10 @@ class FinetuningArguments:
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO Training."}
+    )
     dpo_beta: Optional[float] = field(
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 12fcdef1..c243e322 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -6,6 +6,7 @@ from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version

 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.ploting import plot_loss
@@ -42,6 +43,11 @@ def run_ppo(
         optimize_cuda_cache=True
     )

+    if finetuning_args.ppo_score_norm:
+        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
+        ppo_config.use_score_scaling = True
+        ppo_config.use_score_norm = True
+
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
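
Note (editor's sketch, not part of the patch): the new ppo_score_norm flag simply maps onto two switches of trl's PPOConfig. Below is a minimal standalone sketch of the same wiring, assuming trl >= 0.5.1.dev0; the model name and learning rate are hypothetical placeholders, and the surrounding PPO training loop is not shown.

# Sketch: enable trl's reward-score scaling and normalization on a PPOConfig,
# mirroring what the patch does when finetuning_args.ppo_score_norm is set.
from transformers.utils.versions import require_version
from trl import PPOConfig

# Score scaling/normalization landed in trl after 0.5.0, hence the dev pin.
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")

ppo_config = PPOConfig(
    model_name="path/to/base-model",  # hypothetical
    learning_rate=1e-5,               # hypothetical
)
ppo_config.use_score_scaling = True   # standardize rewards using running statistics
ppo_config.use_score_norm = True      # per trl docs, only applies when use_score_scaling is True

Performing the require_version check only inside the flag's branch means users on a stable trl release are unaffected unless they explicitly opt into score normalization.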