support ppo score norm (trl 0.5.1.dev required)
parent 9020524418
commit 53e33418d0
@@ -61,6 +61,10 @@ class FinetuningArguments:
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO Training."}
+    )
     dpo_beta: Optional[float] = field(
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
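The new field plugs into HfArgumentParser the same way as the existing ones, so it surfaces as a --ppo_score_norm command-line flag. A minimal sketch of that round trip, using a trimmed-down stand-in for FinetuningArguments (the Args class below is illustrative, not the real definition):

# Illustrative stand-in showing how the new dataclass field becomes
# a CLI flag through HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class Args:
    ppo_score_norm: Optional[bool] = field(
        default=False,
        metadata={"help": "Use score normalization in PPO Training."}
    )

parser = HfArgumentParser(Args)
# Passing `--ppo_score_norm True` on the command line flips the flag on.
(args,) = parser.parse_args_into_dataclasses(["--ppo_score_norm", "True"])
print(args.ppo_score_norm)  # True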
@@ -6,6 +6,7 @@ from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version

 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.ploting import plot_loss
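require_version raises an ImportError, with the hint appended to the message, when the installed package does not satisfy the constraint; that is what gates the new code path in the next hunk. A small usage sketch:

# require_version checks the installed package against the constraint and
# raises ImportError (hint included) when it does not match.
from transformers.utils.versions import require_version

try:
    require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
except ImportError as err:
    print(err)  # reports the version that was found and how to fix it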
@@ -42,6 +43,11 @@ def run_ppo(
         optimize_cuda_cache=True
     )

+    if finetuning_args.ppo_score_norm:
+        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
+        ppo_config.use_score_scaling = True
+        ppo_config.use_score_norm = True
+
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
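For intuition about the two flags: in trl, score scaling divides reward scores by a running standard deviation, and score norm additionally centers them on the running mean. The running-moments sketch below is conceptual, not trl's exact implementation:

# Conceptual sketch of score scaling / score normalization: keep running
# moments over reward scores and whiten each batch. Mirrors the idea behind
# trl's use_score_scaling / use_score_norm, not its exact code.
import torch

class RunningMoments:
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, xs: torch.Tensor) -> None:
        # Parallel (Chan et al.) update of the running mean and variance.
        batch_mean = xs.mean().item()
        batch_var = xs.var(unbiased=False).item()
        batch_count = xs.numel()
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean += delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count
        self.var = (m2 + delta ** 2 * self.count * batch_count / total) / total
        self.count = total

moments = RunningMoments()
scores = torch.tensor([0.2, 1.5, -0.7, 0.9])
moments.update(scores)
std = moments.var ** 0.5 + 1e-8
scaled = scores / std                   # score scaling: rescale only
normed = (scores - moments.mean) / std  # score norm: full whitening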