From 8845e94f917b503bbee0604d7290efea7260a30c Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Wed, 3 Jul 2024 19:45:51 +0800 Subject: [PATCH] fix #4609 unwrap_model_for_generation(reward_model) is necessary for zero3 training --- src/llamafactory/train/ppo/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 37d9d37e..6a05b704 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -393,7 +393,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): else: reward_model = self.reward_model - with self.amp_context: # support bf16 + with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16 _, _, values = reward_model(**batch, return_dict=True, use_cache=False) if self.finetuning_args.reward_model_type == "lora": @@ -496,4 +496,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.model.save_checkpoint(output_dir) elif self.args.should_save: - self._save(output_dir) + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + self._save(output_dir, state_dict=unwrapped_model.state_dict())