From 33b437277846d4f0b64c13a0bc892ef4f345a21e Mon Sep 17 00:00:00 2001
From: Jonery
Date: Mon, 17 Jun 2024 18:18:10 +0800
Subject: [PATCH 1/6] adapt for badam with ds zero3

---
 src/llamafactory/hparams/parser.py    | 12 ++++++------
 src/llamafactory/train/sft/trainer.py | 15 +++++++++++++++
 src/llamafactory/train/utils.py       |  7 +++++++
 3 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 6311297e..fe108657 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -184,12 +184,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-    if (
-        finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode.value == "distributed"
-    ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
+    # if (
+    #     finetuning_args.use_badam
+    #     and finetuning_args.badam_mode == "layer"
+    #     and training_args.parallel_mode.value == "distributed"
+    # ):
+    #     raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
 
     if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
         raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 35671e1b..cd73bf5c 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -55,6 +55,21 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
+    def training_step(self, *args, **kwargs):
+        r"""
+        Update the reference to deepspeed optimizer
+        """
+        if self.finetuning_args.use_badam and \
+           self.args.deepspeed_plugin is not None and \
+           self.args.deepspeed_plugin.zero_stage == 3:
+
+            ds_optim = self.optimizer.optimizer
+            badam_optim = ds_optim.optimizer
+            badam_optim.ds_optimizer = ds_optim
+
+        return super().training_step(*args, **kwargs)
+
+
     def prediction_step(
         self,
         model: "torch.nn.Module",
diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/utils.py
index 23834f2d..b189922b 100644
--- a/src/llamafactory/train/utils.py
+++ b/src/llamafactory/train/utils.py
@@ -309,6 +309,12 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
+    ds_zero3_enabled = False
+    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
+        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
+        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
+        ds_zero3_enabled = True
+
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
 
@@ -321,6 +327,7 @@ def _create_badam_optimizer(
             start_block=finetuning_args.badam_start_block,
             switch_mode=finetuning_args.badam_switch_mode,
             verbose=finetuning_args.badam_verbose,
+            ds_zero3_enabled=ds_zero3_enabled
         )
         logger.info(
             f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
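
Under DeepSpeed ZeRO-3 the optimizer held by the Trainer is a wrapper around the
DeepSpeed ZeRO optimizer, which in turn wraps the BAdam BlockOptimizer, and the
training_step override above hands BAdam a reference back to the DeepSpeed
optimizer before each step. A minimal sketch of that unwrapping (illustrative
only: the helper name is invented here, the attribute chain follows the hunk
above, and the exact wrapper nesting depends on the accelerate/DeepSpeed
versions in use):

    def rewire_badam_for_zero3(trainer) -> None:
        # trainer.optimizer           -> accelerate-side wrapper held by the Trainer
        # trainer.optimizer.optimizer -> DeepSpeed ZeRO-3 optimizer
        # ds_optim.optimizer          -> BAdam BlockOptimizer wrapped inside it
        ds_optim = trainer.optimizer.optimizer
        badam_optim = ds_optim.optimizer
        # Give BAdam a handle on the DeepSpeed optimizer, presumably so that
        # switching the active block can reach the ZeRO-3 partitioned state.
        badam_optim.ds_optimizer = ds_optim
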
From b2fc9cc15f2fc3c9536681260beedd3a224ef40f Mon Sep 17 00:00:00 2001
From: Jonery
Date: Mon, 17 Jun 2024 18:29:36 +0800
Subject: [PATCH 2/6] update gitignore

---
 .gitignore | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.gitignore b/.gitignore
index 0355c666..2486e728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,5 @@ cython_debug/
 user.config
 saves/
 cache/
+wandb
+ds_badam_exp
\ No newline at end of file

From 0f72aac8c9227e33ad20d2b1641b1c9faae16a5f Mon Sep 17 00:00:00 2001
From: Jonery
Date: Tue, 18 Jun 2024 12:27:47 +0800
Subject: [PATCH 3/6] Support distributed BAdam.

---
 src/llamafactory/hparams/parser.py    | 24 ++++++++++--------------
 src/llamafactory/train/dpo/trainer.py |  6 ++++++
 src/llamafactory/train/kto/trainer.py |  6 ++++++
 src/llamafactory/train/ppo/trainer.py |  6 ++++++
 src/llamafactory/train/pt/trainer.py  |  6 ++++++
 src/llamafactory/train/rm/trainer.py  |  6 ++++++
 src/llamafactory/train/sft/trainer.py | 22 ++++++----------------
 7 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index f922bbfd..598a34e5 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -209,24 +209,20 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
-<<<<<<< HEAD
-    # if (
-    #     finetuning_args.use_badam
-    #     and finetuning_args.badam_mode == "layer"
-    #     and training_args.parallel_mode.value == "distributed"
-    # ):
-    #     raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
-=======
     if (
         finetuning_args.use_badam
-        and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
+        and training_args.parallel_mode.value == "distributed"
     ):
-        raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
->>>>>>> upstream/main
+        if finetuning_args.badam_mode == "ratio":
+            raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
+        if (finetuning_args.badam_mode == "layer"
+            and training_args.deepspeed_plugin is not None
+            and training_args.deepspeed_plugin.zero_stage < 3
+        ):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {self.args.deepspeed_plugin.zero_stage}")
 
-    if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
-        raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
+    if (finetuning_args.use_galore) and training_args.deepspeed is not None:
+        raise ValueError("GaLore are incompatible with DeepSpeed yet.")
 
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 9928d0bc..284bf41a 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -100,6 +100,12 @@ class CustomDPOTrainer(DPOTrainer):
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 91d68975..d8b609e0 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -95,6 +95,12 @@ class CustomKTOTrainer(KTOTrainer):
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index df4a37be..5a18cbaa 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -170,6 +170,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index f9e04cb5..1e5e9f6a 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -52,6 +52,12 @@ class CustomTrainer(Trainer):
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 7f91e5f5..5d0e6263 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -76,6 +76,12 @@ class PairwiseTrainer(Trainer):
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 0628ea59..9446d245 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -57,9 +57,14 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
-
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
+            if (self.args.deepspeed_plugin is not None
+                and self.args.deepspeed_plugin.zero_stage == 3
+            ):
+                from badam.utils import BAdamZeRO3Callback
+                self.callback_handler.add_callback(BAdamZeRO3Callback)
+
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
@@ -80,21 +85,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         if self.processor is not None:
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
-    def training_step(self, *args, **kwargs):
-        r"""
-        Update the reference to deepspeed optimizer
-        """
-        if self.finetuning_args.use_badam and \
-           self.args.deepspeed_plugin is not None and \
-           self.args.deepspeed_plugin.zero_stage == 3:
-
-            ds_optim = self.optimizer.optimizer
-            badam_optim = ds_optim.optimizer
-            badam_optim.ds_optimizer = ds_optim
-
-        return super().training_step(*args, **kwargs)
-
-
     def prediction_step(
         self,
         model: "torch.nn.Module",

From 8f7c78b64138602406af748b0e15948ebbd2dcb5 Mon Sep 17 00:00:00 2001
From: Jonery
Date: Tue, 18 Jun 2024 12:39:26 +0800
Subject: [PATCH 4/6] fix typo

---
 src/llamafactory/hparams/parser.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 598a34e5..680559ac 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -219,7 +219,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             and training_args.deepspeed_plugin is not None
             and training_args.deepspeed_plugin.zero_stage < 3
         ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {self.args.deepspeed_plugin.zero_stage}")
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
 
     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")

From 97c523516093961983037922e7fc84e4010d5fec Mon Sep 17 00:00:00 2001
From: Jonery
Date: Tue, 18 Jun 2024 13:50:26 +0800
Subject: [PATCH 5/6] add example

---
 examples/extras/badam/llama3_badam_sft.yaml | 40 +++++++++++++++++++++
 examples/extras/badam/train_single_gpu.sh   | 37 +++++++++++++++++++
 examples/extras/badam/train_zero3.sh        | 39 ++++++++++++++++++++
 3 files changed, 116 insertions(+)
 create mode 100644 examples/extras/badam/llama3_badam_sft.yaml
 create mode 100644 examples/extras/badam/train_single_gpu.sh
 create mode 100644 examples/extras/badam/train_zero3.sh

diff --git a/examples/extras/badam/llama3_badam_sft.yaml b/examples/extras/badam/llama3_badam_sft.yaml
new file mode 100644
index 00000000..f5adb220
--- /dev/null
+++ b/examples/extras/badam/llama3_badam_sft.yaml
@@ -0,0 +1,40 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+use_badam: true
+badam_switch_mode: ascending
+badam_switch_interval: 50
+badam_verbose: 2
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 1024
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/llama3-8b/full/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+learning_rate: 1.0e-6
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
diff --git a/examples/extras/badam/train_single_gpu.sh b/examples/extras/badam/train_single_gpu.sh
new file mode 100644
index 00000000..8af79007
--- /dev/null
+++ b/examples/extras/badam/train_single_gpu.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+cd ../../..
+
+llamafactory-cli train \
+    --stage sft \
+    --do_train True \
+    --model_name_or_path meta-llama/Llama-2-13b-hf \
+    --preprocessing_num_workers 16 \
+    --finetuning_type full \
+    --template default \
+    --flash_attn auto \
+    --dataset_dir data \
+    --dataset alpaca_en_demo \
+    --cutoff_len 1024 \
+    --learning_rate 1e-6 \
+    --num_train_epochs 3.0 \
+    --max_samples 100000 \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --max_grad_norm 1.0 \
+    --logging_steps 5 \
+    --save_steps 100 \
+    --warmup_steps 0 \
+    --optim adamw_torch \
+    --packing False \
+    --report_to none \
+    --use_badam True \
+    --output_dir saves/LLaMA2-13B/full/BAdam \
+    --plot_loss True \
+    --ddp_timeout 180000000 \
+    --include_num_input_tokens_seen True \
+    --badam_mode layer \
+    --badam_switch_mode ascending \
+    --badam_switch_interval 50
\ No newline at end of file
diff --git a/examples/extras/badam/train_zero3.sh b/examples/extras/badam/train_zero3.sh
new file mode 100644
index 00000000..3b182134
--- /dev/null
+++ b/examples/extras/badam/train_zero3.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+
+cd ../../..
+
+llamafactory-cli train \
+    --stage sft \
+    --do_train True \
+    --model_name_or_path meta-llama/Llama-2-13b-hf \
+    --preprocessing_num_workers 16 \
+    --finetuning_type full \
+    --template default \
+    --flash_attn auto \
+    --dataset_dir data \
+    --dataset alpaca_en_demo \
+    --cutoff_len 1024 \
+    --learning_rate 1e-6 \
+    --num_train_epochs 3.0 \
+    --max_samples 100000 \
+    --per_device_train_batch_size 8 \
+    --gradient_accumulation_steps 2 \
+    --lr_scheduler_type cosine \
+    --max_grad_norm 1.0 \
+    --logging_steps 5 \
+    --save_steps 100 \
+    --warmup_steps 0 \
+    --optim adamw_torch \
+    --packing False \
+    --report_to none \
+    --use_badam True \
+    --output_dir saves/LLaMA2-13B/full/BAdam \
+    --fp16 True \
+    --plot_loss True \
+    --ddp_timeout 180000000 \
+    --include_num_input_tokens_seen True \
+    --badam_mode layer \
+    --badam_switch_mode ascending \
+    --badam_switch_interval 50 \
+    --deepspeed cache/ds_z3_config.json
\ No newline at end of file
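
The ZeRO-3 example script above passes --deepspeed cache/ds_z3_config.json, but no such
config file is added anywhere in this series. A minimal way to produce one (a sketch using
standard DeepSpeed options with HF-Trainer "auto" values; the file path comes from the
script, everything else is an assumption to adapt to your setup):

    import json
    from pathlib import Path

    # Standard ZeRO-3 settings; "auto" lets the HF Trainer fill in values from its own args.
    ds_z3_config = {
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,  # needed for custom optimizers such as BAdam
        "fp16": {"enabled": "auto"},
        "bf16": {"enabled": "auto"},
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "stage3_gather_16bit_weights_on_model_save": True,
        },
    }

    # Write the config to the path referenced by train_zero3.sh.
    Path("cache").mkdir(exist_ok=True)
    with open("cache/ds_z3_config.json", "w") as f:
        json.dump(ds_z3_config, f, indent=2)
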
From 5c2ff1b749a265dd3c979189ec491d8ac911a6f6 Mon Sep 17 00:00:00 2001
From: Jonery
Date: Wed, 19 Jun 2024 12:29:40 +0800
Subject: [PATCH 6/6] Cleaner integration.

---
 src/llamafactory/hparams/parser.py      |  7 ++-----
 src/llamafactory/train/dpo/trainer.py   | 12 +++---------
 src/llamafactory/train/kto/trainer.py   | 12 +++---------
 src/llamafactory/train/ppo/trainer.py   | 12 +++---------
 src/llamafactory/train/pt/trainer.py    | 12 +++---------
 src/llamafactory/train/rm/trainer.py    | 12 +++---------
 src/llamafactory/train/sft/trainer.py   | 11 +++--------
 src/llamafactory/train/trainer_utils.py | 10 ++++------
 8 files changed, 24 insertions(+), 64 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 680559ac..c7c17ae0 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -215,11 +215,8 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-wise BAdam does not yet support distributed training, use layer-wise BAdam: --badam_mode layer")
-        if (finetuning_args.badam_mode == "layer"
-            and training_args.deepspeed_plugin is not None
-            and training_args.deepspeed_plugin.zero_stage < 3
-        ):
-            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage, got stage {training_args.deepspeed_plugin.zero_stage}")
+        if finetuning_args.badam_mode == "layer" and (not is_deepspeed_zero3_enabled()):
+            raise ValueError(f"Layer-wise BAdam only supports DeepSpeed ZeRO 3 stage.")
 
     if (finetuning_args.use_galore) and training_args.deepspeed is not None:
         raise ValueError("GaLore are incompatible with DeepSpeed yet.")
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 284bf41a..a3e0e961 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -96,15 +96,9 @@ class CustomDPOTrainer(DPOTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index d8b609e0..0d50987f 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -91,15 +91,9 @@ class CustomKTOTrainer(KTOTrainer):
             self.ref_model.eval()
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 5a18cbaa..f81831e6 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -166,15 +166,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
         r"""
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 1e5e9f6a..d3516b41 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -48,15 +48,9 @@ class CustomTrainer(Trainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 5d0e6263..433251cf 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -72,15 +72,9 @@ class PairwiseTrainer(Trainer):
         self.processor = processor
         self.can_return_loss = True  # override property to return eval_loss
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 9446d245..45799b96 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -56,14 +56,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
 
         if finetuning_args.use_badam:
-            from badam import clip_grad_norm_for_sparse_tensor
-            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
-
-            if (self.args.deepspeed_plugin is not None
-                and self.args.deepspeed_plugin.zero_stage == 3
-            ):
-                from badam.utils import BAdamZeRO3Callback
-                self.callback_handler.add_callback(BAdamZeRO3Callback)
+            from badam import clip_grad_norm_old_version, BAdamCallback
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.callback_handler.add_callback(BAdamCallback)
 
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 98c38842..b57f5a6e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -371,11 +371,8 @@ def _create_badam_optimizer(
         dict(params=decay_params, weight_decay=training_args.weight_decay),
     ]
 
-    ds_zero3_enabled = False
-    if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
-        assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
-        assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
-        ds_zero3_enabled = True
+    from transformers.integrations import is_deepspeed_zero3_enabled
+    ds_zero3_enabled = is_deepspeed_zero3_enabled()
 
     if finetuning_args.badam_mode == "layer":
         from badam import BlockOptimizer
@@ -400,6 +397,7 @@ def _create_badam_optimizer(
     elif finetuning_args.badam_mode == "ratio":
         from badam import BlockOptimizerRatio
+        assert not ds_zero3_enabled, "BAdam with ratio-based update does not support Deepspeed ZeRO-3 yet, use layer-wise update instead: --badam_mode layer."
 
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
             param_groups=param_groups,
@@ -411,7 +409,7 @@ def _create_badam_optimizer(
             **optim_kwargs,
         )
         logger.info(
-            f"Using BAdam optimizer with ratio-wise update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
             f"mask mode is {finetuning_args.badam_mask_mode}"
         )
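
After the last patch, the per-trainer wiring is identical everywhere: swap the accelerator's
clip_grad_norm_ for BAdam's variant and register BAdamCallback, which presumably absorbs the
ZeRO-3 re-wiring the first patch did by hand in training_step. A condensed sketch of that
shared hook-up (the helper name is invented; clip_grad_norm_old_version and BAdamCallback are
the names used in the diff, assumed to be exported by the badam package):

    from types import MethodType

    def attach_badam(trainer) -> None:
        # Sketch of the hook-up repeated in each trainer's __init__ above.
        from badam import BAdamCallback, clip_grad_norm_old_version

        # Gradient clipping has to tolerate the partial/sparse updates produced by
        # block-wise optimization, so the accelerator method is monkey-patched.
        trainer.accelerator.clip_grad_norm_ = MethodType(
            clip_grad_norm_old_version, trainer.accelerator
        )
        # The callback is expected to handle the DeepSpeed ZeRO-3 specific setup.
        trainer.callback_handler.add_callback(BAdamCallback)
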