fix #4579
This commit is contained in:
parent
96a5044394
commit
677c86594e
|
@ -53,6 +53,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||
self.processor = processor
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
self.deepspeed = self._wrap_model(self.model_wrapped)
|
||||
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
|
|
Loading…
Reference in New Issue