From e874c00906c765b81c0e5ff9c7b3679557da8e0e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 11 Mar 2024 00:42:54 +0800
Subject: [PATCH] fix #2775

---
 src/llmtuner/train/utils.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 75006ee0..425ff18e 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -1,5 +1,5 @@
 import math
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
 from transformers.optimization import get_scheduler
@@ -151,16 +151,18 @@ def create_custom_optimzer(
         return None
 
     require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
-    galore_params = []
+    galore_params: List[torch.nn.Parameter] = []
     galore_targets = finetuning_args.galore_target.split(",")
 
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
-            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    galore_params.append(param)
 
-    id_galore_params = [id(p) for p in galore_params]
-    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
-    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
+    id_galore_params = {id(param) for param in galore_params}
+    trainable_params = filter(lambda param: param.requires_grad, model.parameters())
+    non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
 
     if training_args.optim == "adamw_torch":
         optim_class = GaLoreAdamW
@@ -168,6 +170,7 @@
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
         }
 
     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
@@ -176,6 +179,7 @@
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
             "optim_bits": 8,
             "is_paged": "paged" in training_args.optim,
         }
@@ -184,6 +188,7 @@
         optim_class = GaLoreAdafactor
         optim_kwargs = {
             "lr": training_args.learning_rate,
+            "weight_decay": training_args.weight_decay,
         }
 
     else:
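
Note (not part of the patch): the parameter-grouping rule this change switches to can be sketched in isolation. Only trainable tensors with more than one dimension inside the targeted torch.nn.Linear modules are handed to GaLore; biases and every other trainable parameter fall into the regular group. The toy model and target names below are illustrative placeholders, not code from llmtuner.

from typing import List

import torch

# Toy stand-ins for llmtuner's model and finetuning_args.galore_target (illustrative only).
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),  # module name "0": 2-D weight is GaLore-eligible, bias is not
    torch.nn.ReLU(),          # module name "1": no parameters
    torch.nn.Linear(32, 4),   # module name "2"
)
galore_targets = ["0", "2"]   # substrings matched against module names

galore_params: List[torch.nn.Parameter] = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
        for param in module.parameters():
            # GaLore projects 2-D weight matrices only; skip 1-D tensors such as biases.
            if param.requires_grad and len(param.shape) > 1:
                galore_params.append(param)

# Every remaining trainable parameter (the biases here) goes to the regular optimizer group.
id_galore_params = {id(param) for param in galore_params}
non_galore_params = [
    param for param in model.parameters()
    if param.requires_grad and id(param) not in id_galore_params
]

print(len(galore_params), len(non_galore_params))  # 2 weight matrices, 2 biases

In the patch itself, the id-based filter runs over all of model.parameters(), so trainable parameters outside the targeted modules (e.g. embeddings or norms) also end up in the non-GaLore group, and all three GaLore optimizer variants now receive training_args.weight_decay.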