diff --git a/README.md b/README.md
index 535deefb..ceb84cb1 100644
--- a/README.md
+++ b/README.md
@@ -313,6 +313,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --per_device_train_batch_size 2 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
+    --top_k 0 \
+    --top_p 0.9 \
     --logging_steps 10 \
     --save_steps 1000 \
     --learning_rate 1e-5 \
diff --git a/README_zh.md b/README_zh.md
index 10418c3d..740dab0a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -313,6 +313,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --per_device_train_batch_size 2 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
+    --top_k 0 \
+    --top_p 0.9 \
     --logging_steps 10 \
     --save_steps 1000 \
     --learning_rate 1e-5 \
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index cfdc8b24..d39812c7 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -74,6 +74,10 @@ class RLHFArguments:
         default=None,
         metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
     )
+    ppo_epochs: Optional[int] = field(
+        default=4,
+        metadata={"help": "Number of optimisation epochs per batch of samples"},
+    )
     ppo_score_norm: Optional[bool] = field(
         default=False,
         metadata={"help": "Use score normalization in PPO training."}
     )
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 2a779edd..3e7d8053 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -45,7 +45,7 @@ def run_ppo(
         mini_batch_size=training_args.per_device_train_batch_size,
         batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
-        ppo_epochs=1,
+        ppo_epochs=finetuning_args.ppo_epochs,
         max_grad_norm=training_args.max_grad_norm,
         seed=training_args.seed,
         optimize_device_cache=True,
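
For context, below is a minimal sketch (not part of the diff) of the intent of the change: the number of PPO optimisation epochs is read from the parsed RLHFArguments instead of being hard-coded to 1 in workflow.py. It assumes a TRL release contemporary with this code, whose PPOConfig still accepts the keyword arguments used in workflow.py above; the batch-size values are illustrative, not taken from the repo.

# Minimal sketch, assuming the old-style trl.PPOConfig API used in workflow.py.
from dataclasses import dataclass, field
from typing import Optional

from trl import PPOConfig


@dataclass
class RLHFArguments:
    # Mirrors the field added in finetuning_args.py above.
    ppo_epochs: Optional[int] = field(
        default=4,
        metadata={"help": "Number of optimisation epochs per batch of samples"},
    )


# In the repo this object is populated from the CLI; here we just use the default.
finetuning_args = RLHFArguments()

ppo_config = PPOConfig(
    mini_batch_size=2,                      # illustrative per-device train batch size
    batch_size=2 * 4,                       # per-device batch size * gradient accumulation steps
    gradient_accumulation_steps=4,
    ppo_epochs=finetuning_args.ppo_epochs,  # previously hard-coded as ppo_epochs=1
)
print(ppo_config.ppo_epochs)                # 4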