From d3dccd0693ede18a99f04780f2fd6e3a89810405 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 4 Dec 2023 19:00:19 +0800 Subject: [PATCH] fix ppo trainer save logic --- src/llmtuner/train/ppo/trainer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index f02dbdc3..40129840 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -from transformers.trainer_pt_utils import remove_dummy_checkpoint from trl import PPOTrainer from trl.core import PPODecorators, logprobs_from_logits @@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model)) except ValueError: logger.warning( - " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" - " zero_to_fp32.py to recover weights" + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," + " use zero_to_fp32.py to recover weights" ) self._save(output_dir, state_dict={}) - remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint + file = os.path.join(output_dir, filename) + if os.path.isfile(file): + os.remove(file) + self.model.save_checkpoint(output_dir) # wrapped model