fix ppo trainer save logic
This commit is contained in:
parent
997b65f291
commit
d3dccd0693
|
@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
|||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
||||
" zero_to_fp32.py to recover weights"
|
||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
||||
" use zero_to_fp32.py to recover weights"
|
||||
)
|
||||
self._save(output_dir, state_dict={})
|
||||
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
||||
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
|
||||
file = os.path.join(output_dir, filename)
|
||||
if os.path.isfile(file):
|
||||
os.remove(file)
|
||||
|
||||
self.model.save_checkpoint(output_dir) # wrapped model
|
||||
|
|
Loading…
Reference in New Issue