diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py index be8fca5d..ed00b437 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/tuner/ppo/workflow.py @@ -42,7 +42,8 @@ def run_ppo( ppo_epochs=1, max_grad_norm=training_args.max_grad_norm, seed=training_args.seed, - optimize_cuda_cache=True + optimize_cuda_cache=True, + accelerator_kwargs={"step_scheduler_with_optimizer": False} ) if finetuning_args.ppo_score_norm: