diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 8f18317f..f0a86dff 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -53,6 +53,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.processor = processor if finetuning_args.pissa_convert: + if self.is_deepspeed_enabled: + self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config + self.deepspeed = self._wrap_model(self.model_wrapped) self.save_model(os.path.join(self.args.output_dir, "pissa_init")) if finetuning_args.use_badam: