fix ppo trainer

This commit is contained in:
hiyouga 2024-07-10 11:05:45 +08:00
parent 2f09520c0d
commit fb0c400116
1 changed files with 2 additions and 1 deletions

View File

@ -106,7 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with == "tensorboard": # tensorboard raises error about accelerator_kwargs
if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None
# Create optimizer and scheduler