From 5c62881c5a59cfcc5a76d365263c8ad8c817ce49 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 21 Apr 2024 18:53:22 +0800
Subject: [PATCH] fix bug in galore optimizer

---
 examples/extras/galore/sft.sh |  6 +++---
 src/llmtuner/train/utils.py   | 14 ++++----------
 2 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/examples/extras/galore/sft.sh b/examples/extras/galore/sft.sh
index 1e46ac1f..da1779ed 100644
--- a/examples/extras/galore/sft.sh
+++ b/examples/extras/galore/sft.sh
@@ -11,8 +11,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --use_galore \
     --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_scale 2.0 \
     --galore_rank 128 \
+    --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -29,8 +29,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 30.0 \
-    --max_samples 300 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
     --pure_bf16
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 2835eddf..d3f17116 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -234,14 +234,6 @@ def _create_galore_optimizer(
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
 
-        def optimizer_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                optimizer_dict[param].step()
-                optimizer_dict[param].zero_grad()
-
-        for param in trainable_params:
-            param.register_post_accumulate_grad_hook(optimizer_hook)
-
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
@@ -391,9 +383,11 @@ def create_custom_scheduler(
                 num_training_steps=num_training_steps * 2,
             )
 
-        def scheduler_hook(param: "torch.nn.Parameter"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()
 
         for param in optimizer_dict.keys():
-            param.register_post_accumulate_grad_hook(scheduler_hook)
+            param.register_post_accumulate_grad_hook(optimizer_hook)
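
Note: the patch merges the separate optimizer and scheduler hooks into a single post-accumulate-grad hook, so each parameter's dedicated optimizer steps, zeroes its gradient, and advances its scheduler as soon as that gradient is ready. Below is a minimal standalone sketch of this per-parameter hook pattern, not the project's actual code: the model, learning rate, decay schedule, and step counts are illustrative placeholders, and it assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook.

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    # Toy model standing in for the real network; every trainable parameter
    # gets its own optimizer and scheduler, as in layer-wise GaLore.
    model = torch.nn.Linear(8, 8)
    optimizer_dict, scheduler_dict = {}, {}
    for param in model.parameters():
        if param.requires_grad:
            optimizer_dict[param] = torch.optim.AdamW([param], lr=5e-5)
            scheduler_dict[param] = LambdaLR(
                optimizer_dict[param], lambda step: max(0.0, 1.0 - step / 1000)
            )

    def optimizer_hook(param: torch.nn.Parameter) -> None:
        # Fires once the gradient for `param` is fully accumulated:
        # step and zero this parameter's optimizer, then step its scheduler.
        if param.grad is not None:
            optimizer_dict[param].step()
            optimizer_dict[param].zero_grad()
            scheduler_dict[param].step()

    for param in optimizer_dict:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    # A single backward() now drives all per-parameter updates;
    # no global optimizer.step() is needed.
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()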