forked from p04798526/LLaMA-Factory-Mirror
fix bug in galore optimizer
parent f58425ab45
commit 5c62881c5a
@@ -11,8 +11,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --use_galore \
     --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_scale 2.0 \
+    --galore_rank 128 \
     --galore_scale 2.0 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
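
For context on the flags touched here, below is a minimal sketch of how --galore_target, --galore_rank, and --galore_scale typically map onto GaLore parameter groups. It assumes the galore_torch package and its GaLoreAdamW optimizer; the build_galore_optimizer helper and the name-matching rule are illustrative only, not code from this repository.

```python
# Sketch only: build GaLore parameter groups for mlp/self_attn modules,
# mirroring --galore_rank 128 and --galore_scale 2.0 from the script above.
import torch.nn as nn
from galore_torch import GaLoreAdamW  # assumed dependency


def build_galore_optimizer(model: nn.Module, lr: float = 5e-5):
    galore_params, regular_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # illustrative matching rule for --galore_target mlp,self_attn
        # (GaLore projection only applies to 2-D weight matrices)
        if param.ndim == 2 and ("mlp" in name or "self_attn" in name):
            galore_params.append(param)
        else:
            regular_params.append(param)

    param_groups = [
        {"params": regular_params},
        # rank / scale / update_proj_gap / proj_type keys follow the GaLore README
        {"params": galore_params, "rank": 128, "scale": 2.0,
         "update_proj_gap": 200, "proj_type": "std"},
    ]
    return GaLoreAdamW(param_groups, lr=lr)
```
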
@@ -29,8 +29,8 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --evaluation_strategy steps \
     --load_best_model_at_end \
     --learning_rate 5e-5 \
-    --num_train_epochs 30.0 \
-    --max_samples 300 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
     --pure_bf16
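
As a sanity check on the corrected values, here is a small sketch of the resulting data split and step count. The effective batch size is not visible in this hunk, so the value of 1 below is an assumption.

```python
# Rough step-count arithmetic for the corrected example config.
max_samples = 3000
val_size = 0.1
num_train_epochs = 3.0
effective_batch_size = 1  # assumed; not shown in this hunk

train_samples = int(max_samples * (1 - val_size))        # 2700
eval_samples = max_samples - train_samples               # 300
steps_per_epoch = train_samples // effective_batch_size
total_steps = int(steps_per_epoch * num_train_epochs)    # 8100 under these assumptions

print(train_samples, eval_samples, total_steps)
```
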
@@ -234,14 +234,6 @@ def _create_galore_optimizer(
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
 
-        def optimizer_hook(param: "torch.nn.Parameter"):
-            if param.grad is not None:
-                optimizer_dict[param].step()
-                optimizer_dict[param].zero_grad()
-
-        for param in trainable_params:
-            param.register_post_accumulate_grad_hook(optimizer_hook)
-
         optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
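
The block deleted above registered one hook per trainable parameter via PyTorch's Tensor.register_post_accumulate_grad_hook, which fires once a parameter's gradient has finished accumulating during backward(). A minimal, self-contained illustration of that mechanism (the report callback is just for demonstration):

```python
# Minimal demo of Tensor.register_post_accumulate_grad_hook (torch >= 2.1):
# the hook fires for each parameter once its gradient is fully accumulated
# during backward(), which is what the removed optimizer_hook relied on.
import torch

model = torch.nn.Linear(8, 4)


def report(param: torch.nn.Parameter) -> None:
    print("grad ready:", tuple(param.shape), param.grad.norm().item())


for param in model.parameters():
    param.register_post_accumulate_grad_hook(report)

model(torch.randn(2, 8)).sum().backward()  # prints once per parameter
```

The commit removes the hook from _create_galore_optimizer because, as the next hunk shows, the optimizer step is folded into the hook registered in create_custom_scheduler.
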
@@ -391,9 +383,11 @@ def create_custom_scheduler(
                 num_training_steps=num_training_steps * 2,
             )
 
-        def scheduler_hook(param: "torch.nn.Parameter"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()
 
         for param in optimizer_dict.keys():
-            param.register_post_accumulate_grad_hook(scheduler_hook)
+            param.register_post_accumulate_grad_hook(optimizer_hook)
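
Putting the fix together, the layerwise path ends up with a single hook per parameter that steps the optimizer, clears the gradient, and then steps that parameter's scheduler. A runnable sketch of the resulting pattern, using plain AdamW and LambdaLR in place of the repository's GaLore optimizer and DummyOptimizer wrapper:

```python
# Sketch of the fixed layerwise update: one post-accumulate-grad hook per
# parameter that steps both the optimizer and the LR scheduler, in order.
# Plain AdamW/LambdaLR stand in for the GaLore-specific classes.
import torch

model = torch.nn.Linear(8, 4)

optimizer_dict = {p: torch.optim.AdamW([p], lr=5e-5) for p in model.parameters()}
scheduler_dict = {
    p: torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0 / (1 + step))
    for p, opt in optimizer_dict.items()
}


def optimizer_hook(param: torch.nn.Parameter) -> None:
    if param.grad is not None:
        optimizer_dict[param].step()       # update this parameter
        optimizer_dict[param].zero_grad()  # drop its gradient immediately
        scheduler_dict[param].step()       # then advance its LR schedule


for param in optimizer_dict.keys():
    param.register_post_accumulate_grad_hook(optimizer_hook)

for _ in range(3):  # a few toy steps
    model(torch.randn(2, 8)).sum().backward()
```

After the fix, the per-parameter optimizer and scheduler updates are driven by one hook instead of two separately registered ones, so the scheduler step always follows the optimizer step for that parameter.
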