support ppo score norm (trl 0.5.1.dev required)
parent 9020524418
commit 53e33418d0
@@ -61,6 +61,10 @@ class FinetuningArguments:
         default=True,
         metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO Training."}
+    )
     dpo_beta: Optional[float] = field(
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."}
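Since llmtuner parses FinetuningArguments with HfArgumentParser, the new field should surface automatically as a --ppo_score_norm command-line flag. A minimal, self-contained sketch of that pattern (the class and field names come from the diff; the parser wiring here is an assumption, not the repo's actual entry point):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class FinetuningArguments:
    ppo_score_norm: Optional[bool] = field(
        default=False,
        metadata={"help": "Use score normalization in PPO Training."}
    )


# HfArgumentParser turns each dataclass field into a CLI argument,
# so the new option can be toggled with `--ppo_score_norm True`.
parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(["--ppo_score_norm", "True"])
print(finetuning_args.ppo_score_norm)  # True
```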
@@ -6,6 +6,7 @@ from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.ploting import plot_loss
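require_version is a transformers utility that checks the installed package against a requirement string and raises ImportError with the supplied hint appended when the check fails:

```python
from transformers.utils.versions import require_version

# Passes silently when a satisfying trl build is installed; otherwise
# raises ImportError that includes the "To fix: ..." hint below.
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
```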
@@ -42,6 +43,11 @@ def run_ppo(
         optimize_cuda_cache=True
     )
 
+    if finetuning_args.ppo_score_norm:
+        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
+        ppo_config.use_score_scaling = True
+        ppo_config.use_score_norm = True
+
     optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
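In trl, use_score_scaling divides reward scores by a running standard deviation, and use_score_norm additionally centers them by the running mean; the norm option is consulted only when scaling is enabled, which is presumably why the commit sets both flags. A rough behavioral sketch follows (not trl's actual code; RunningMoments and normalize_scores are illustrative names):

```python
class RunningMoments:
    """Running mean/variance over reward batches (parallel Welford update)."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-4  # tiny prior count avoids division by zero

    def update(self, xs):
        n = len(xs)
        batch_mean = sum(xs) / n
        batch_var = sum((x - batch_mean) ** 2 for x in xs) / n
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean += delta * n / total
        m2 = self.var * self.count + batch_var * n + delta ** 2 * self.count * n / total
        self.var = m2 / total
        self.count = total


def normalize_scores(scores, moments, use_score_norm):
    """Whiten reward scores with running statistics before PPO consumes them."""
    moments.update(scores)
    std = (moments.var + 1e-8) ** 0.5
    if use_score_norm:
        return [(s - moments.mean) / std for s in scores]  # zero mean, unit variance
    return [s / std for s in scores]  # scaling alone: rescale, keep the mean
```

Because the statistics are running, early batches are normalized with noisy estimates that stabilize as training proceeds.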