Merge pull request #4580 from hzhaoy/bugfix-deepspeed-pissa

Fix bug when using the PiSSA method with DeepSpeed
This commit is contained in:
hoshi-hiyouga 2024-06-28 00:46:51 +08:00 committed by GitHub
commit ef38daa0a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 0 deletions

View File

@@ -53,6 +53,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.processor = processor
if finetuning_args.pissa_convert:
if self.is_deepspeed_enabled:
self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
self.deepspeed = self._wrap_model(self.model_wrapped)
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
if finetuning_args.use_badam: