fix bug in galore optimizer

hiyouga 2024-04-21 18:53:22 +08:00
parent f58425ab45
commit 5c62881c5a
2 changed files with 7 additions and 13 deletions

@@ -11,8 +11,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --use_galore \
     --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_scale 2.0 \
     --galore_rank 128 \
+    --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -29,8 +29,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 30.0 \
-    --max_samples 300 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
     --pure_bf16
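For reference, the GaLore flags in this script (--galore_target, --galore_rank, --galore_scale) are ultimately turned into galore_torch parameter groups, with the targeted modules getting low-rank projected gradients and everything else optimized normally. The sketch below is only an illustration under assumptions: the toy model, the selection of parameters by tensor shape, and the update_proj_gap value are made up here, not taken from this commit.

import torch
from galore_torch import GaLoreAdamW  # pip install galore-torch

# Toy stand-in for the targeted mlp/self_attn projection weights.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))

galore_params = [p for p in model.parameters() if p.dim() == 2]   # 2-D weights get GaLore
regular_params = [p for p in model.parameters() if p.dim() != 2]  # biases etc. stay dense

param_groups = [
    {"params": regular_params},
    {
        "params": galore_params,
        "rank": 128,             # --galore_rank
        "scale": 2.0,            # --galore_scale
        "update_proj_gap": 200,  # projection refresh interval (assumed value)
        "proj_type": "std",
    },
]
optimizer = GaLoreAdamW(param_groups, lr=5e-5)  # --learning_rate from the script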

@@ -234,14 +234,6 @@ def _create_galore_optimizer(
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
-
-        def optimizer_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                optimizer_dict[param].step()
-                optimizer_dict[param].zero_grad()
-
-        for param in trainable_params:
-            param.register_post_accumulate_grad_hook(optimizer_hook)
 
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
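The DummyOptimizer kept as context above is the object handed back to the Trainer in layer-wise mode: the per-parameter optimizers in optimizer_dict do the real work from gradient hooks, so the Trainer-facing optimizer must be a no-op. A minimal sketch of that placeholder pattern follows; the class name, constructor, and dummy-parameter trick are assumptions for illustration, not the repository's actual definition.

import torch


class NoOpOptimizer(torch.optim.Optimizer):
    """Placeholder optimizer: real updates happen in per-parameter grad hooks."""

    def __init__(self, lr: float = 1e-3, optimizer_dict=None):
        self.optimizer_dict = optimizer_dict or {}
        dummy = torch.randn(1, requires_grad=True)  # base class needs at least one param
        super().__init__([dummy], defaults={"lr": lr})

    def step(self, closure=None):
        return None  # nothing to do: each parameter was stepped inside its hook

    def zero_grad(self, set_to_none: bool = True):
        return None  # gradients were already cleared inside the hooks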
@@ -391,9 +383,11 @@ def create_custom_scheduler(
                 num_training_steps=num_training_steps * 2,
             )
 
-        def scheduler_hook(param: "torch.nn.Parameter"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()
 
         for param in optimizer_dict.keys():
-            param.register_post_accumulate_grad_hook(scheduler_hook)
+            param.register_post_accumulate_grad_hook(optimizer_hook)
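Taken together, the two Python hunks collapse the layer-wise update into a single hook: _create_galore_optimizer no longer registers its own optimizer_hook, and the hook registered in create_custom_scheduler now runs the optimizer step, gradient clearing, and scheduler step in one pass instead of stepping the scheduler alone. Previously two separate hooks were attached to each parameter, and the scheduler hook's param.grad check could observe a gradient the optimizer hook had already cleared. Below is a self-contained sketch of the resulting pattern, with a toy linear layer and plain SGD standing in for the GaLore optimizer (register_post_accumulate_grad_hook needs PyTorch 2.1+).

import torch

model = torch.nn.Linear(8, 8)
optimizer_dict, scheduler_dict = {}, {}

# One optimizer and one scheduler per parameter, as in the layer-wise setup.
for param in model.parameters():
    optimizer_dict[param] = torch.optim.SGD([param], lr=1e-2)
    scheduler_dict[param] = torch.optim.lr_scheduler.ConstantLR(optimizer_dict[param])


def optimizer_hook(param):
    if param.grad is not None:
        optimizer_dict[param].step()       # update this parameter only
        optimizer_dict[param].zero_grad()  # release its gradient immediately
        scheduler_dict[param].step()       # keep its learning-rate schedule in sync


for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(4, 8)).sum()
loss.backward()  # each hook fires as soon as its parameter's gradient is accumulated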