diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index f922bbfd..598a34e5 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -209,24 +209,20 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: ): raise ValueError("Distributed training does not support layer-wise GaLore.") -<<<<<<< HEAD - # if ( - # finetuning_args.use_badam - # and finetuning_args.badam_mode == "layer" - # and training_args.parallel_mode.value == "distributed" - # ): - # raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.") -======= if ( finetuning_args.use_badam - and finetuning_args.badam_mode == "layer" - and training_args.parallel_mode == ParallelMode.DISTRIBUTED + and training_args.parallel_mode.value == "distributed" ): - raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.") ->>>>>>> upstream/main + if finetuning_args.badam_mode == "ratio": + raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer") + if (finetuning_args.badam_mode == "layer" + and training_args.deepspeed_plugin is not None + and training_args.deepspeed_plugin.zero_stage < 3 + ): + raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}") - if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None: - raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.") + if (finetuning_args.use_galore) and training_args.deepspeed is not None: + raise ValueError("GaLore are incompatible with DeepSpeed yet.") if model_args.infer_backend == "vllm": raise ValueError("vLLM backend is only available for API, CLI and Web.") diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 9928d0bc..284bf41a 100644 --- 
a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -100,6 +100,12 @@ class CustomDPOTrainer(DPOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 91d68975..d8b609e0 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -95,6 +95,12 @@ class CustomKTOTrainer(KTOTrainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index df4a37be..5a18cbaa 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -170,6 +170,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: 
r""" Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index f9e04cb5..1e5e9f6a 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -52,6 +52,12 @@ class CustomTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 7f91e5f5..5d0e6263 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -76,6 +76,12 @@ class PairwiseTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 0628ea59..9446d245 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -57,9 +57,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor - self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, 
self.accelerator) + if (self.args.deepspeed_plugin is not None + and self.args.deepspeed_plugin.zero_stage == 3 + ): + from badam.utils import BAdamZeRO3Callback + self.callback_handler.add_callback(BAdamZeRO3Callback) + def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) @@ -80,21 +85,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): if self.processor is not None: getattr(self.processor, "image_processor").save_pretrained(output_dir) - def training_step(self, *args, **kwargs): - r""" - Update the reference to deepspeed optimizer - """ - if self.finetuning_args.use_badam and \ - self.args.deepspeed_plugin is not None and \ - self.args.deepspeed_plugin.zero_stage == 3: - - ds_optim = self.optimizer.optimizer - badam_optim = ds_optim.optimizer - badam_optim.ds_optimizer = ds_optim - - return super().training_step(*args, **kwargs) - - def prediction_step( self, model: "torch.nn.Module",