fix bug
This commit is contained in:
parent
747db40172
commit
8b681ee273
|
@ -66,7 +66,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
if reward_model is not None:
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
if self.is_deepspeed_enabled:
|
||||
if not (
|
||||
getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
|
||||
|
|
Loading…
Reference in New Issue