diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 680559ac..c7c17ae0 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -215,11 +215,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
-        if (finetuning_args.badam_mode == "layer"
-            and training_args.deepspeed_plugin is not None
-            and training_args.deepspeed_plugin.zero_stage < 3
-        ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
+        if finetuning_args.badam_mode == "layer" and not is_deepspeed_zero3_enabled():
+            raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 stage.")
 
     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 284bf41a..a3e0e961 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -96,15 +96,9 @@ class CustomDPOTrainer(DPOTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index d8b609e0..0d50987f 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -91,15 +91,9 @@ class CustomKTOTrainer(KTOTrainer):
                 self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
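Each trainer hunk in this patch applies the same two-line replacement: `Accelerator.clip_grad_norm_` is rebound to BAdam's `clip_grad_norm_old_version` via `types.MethodType`, and a single `BAdamCallback` stands in for the old ZeRO-3-only `BAdamZeRO3Callback`. A minimal sketch of the instance-level rebinding pattern; the `Accelerator` class below is a toy stand-in, not the real `accelerate` object:

```python
from types import MethodType

class Accelerator:  # toy stand-in for accelerate.Accelerator
    def clip_grad_norm_(self, parameters, max_norm):
        return "dense clipping"

def badam_aware_clip(self, parameters, max_norm):
    # Hypothetical replacement that tolerates BAdam's partially
    # materialized (block-wise) gradients.
    return "badam-aware clipping"

acc = Accelerator()
# Rebind on this instance only; other Accelerator instances keep the original.
acc.clip_grad_norm_ = MethodType(badam_aware_clip, acc)
assert acc.clip_grad_norm_([], max_norm=1.0) == "badam-aware clipping"
```

Binding to the instance rather than patching the class keeps the change scoped to the one trainer that opted into BAdam.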
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 5a18cbaa..f81831e6 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -166,15 +166,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 1e5e9f6a..d3516b41 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -48,15 +48,9 @@ class CustomTrainer(Trainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 5d0e6263..433251cf 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -72,15 +72,9 @@ class PairwiseTrainer(Trainer):
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 9446d245..45799b96 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -56,14 +56,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
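All six trainers now register `BAdamCallback` unconditionally instead of attaching `badam.utils.BAdamZeRO3Callback` only when a ZeRO-3 plugin is detected; the stage-specific branching presumably moves inside the BAdam package. For orientation, a hedged sketch of registering a callback class on a HF trainer handler, with an illustrative hook body that is not BAdam's actual logic:

```python
from transformers import TrainerCallback

class BlockSwitchLogger(TrainerCallback):
    # Illustrative stand-in; the real BAdamCallback hooks live in badam.
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % 100 == 0:
            print(f"step {state.global_step}: active parameter block may rotate")

# Mirroring the patched trainers, inside a Trainer subclass __init__:
#     self.callback_handler.add_callback(BlockSwitchLogger)
# add_callback accepts either a class (instantiated internally) or an instance.
```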
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 98c38842..b57f5a6e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -371,11 +371,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
 
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
@@ -400,6 +397,7 @@
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
 
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support DeepSpeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -411,7 +409,7 @@
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
 
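Condensed, the new `_create_badam_optimizer` flow probes `transformers` once for ZeRO-3 and lets each mode react: layer-wise proceeds (and, in the full function, presumably forwards the flag to `BlockOptimizer`), while ratio-based fails fast. A rough sketch under those assumptions, with the optimizer constructors stubbed out as strings:

```python
from transformers.integrations import is_deepspeed_zero3_enabled

def pick_badam_optimizer(badam_mode: str, badam_update_ratio: float) -> str:
    # One library probe replaces the old hand-rolled deepspeed_plugin checks.
    ds_zero3_enabled = is_deepspeed_zero3_enabled()

    if badam_mode == "layer":
        # Layer-wise updates now work with or without ZeRO-3.
        return f"BlockOptimizer(ds_zero3_enabled={ds_zero3_enabled})"
    elif badam_mode == "ratio":
        # Ratio-based updates still lack ZeRO-3 support.
        assert not ds_zero3_enabled, (
            "BAdam with ratio-based update does not support DeepSpeed ZeRO-3 yet, "
            "use layer-wise update instead: --badam_mode layer."
        )
        assert badam_update_ratio > 1e-6
        return f"BlockOptimizerRatio(update_ratio={badam_update_ratio})"
    raise ValueError(f"Unknown badam_mode: {badam_mode}")
```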