hiyouga 2024-03-11 00:42:54 +08:00
parent 352693e2dc
commit e874c00906
1 changed file with 11 additions and 6 deletions

@@ -1,5 +1,5 @@
 import math
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

 import torch
 from transformers.optimization import get_scheduler
@@ -151,16 +151,18 @@ def create_custom_optimzer(
         return None

     require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
-    galore_params = []
+    galore_params: List[torch.nn.Parameter] = []
     galore_targets = finetuning_args.galore_target.split(",")

     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
-            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    galore_params.append(param)

-    id_galore_params = [id(p) for p in galore_params]
-    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
-    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
+    id_galore_params = {id(param) for param in galore_params}
+    trainable_params = filter(lambda param: param.requires_grad, model.parameters())
+    non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]

     if training_args.optim == "adamw_torch":
         optim_class = GaLoreAdamW
@@ -168,6 +170,7 @@ def create_custom_optimzer(
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
         }

     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
@@ -176,6 +179,7 @@ def create_custom_optimzer(
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
             "optim_bits": 8,
             "is_paged": "paged" in training_args.optim,
         }
@@ -184,6 +188,7 @@ def create_custom_optimzer(
         optim_class = GaLoreAdafactor
         optim_kwargs = {
             "lr": training_args.learning_rate,
+            "weight_decay": training_args.weight_decay,
         }

     else:
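
Not part of the commit: a minimal standalone sketch of the new collection rule. GaLore's low-rank projection targets 2-D weight matrices, so the rewritten loop keeps only multi-dimensional trainable parameters of the matched Linear modules and skips 1-D tensors such as biases; the commit also threads training_args.weight_decay into each GaLore optimizer's kwargs. The toy model and the galore_targets values below are hypothetical, for illustration only.

# Minimal sketch (not from the repository) of the parameter filtering added above.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),  # weight is 2-D (kept), bias is 1-D (skipped)
    torch.nn.ReLU(),
    torch.nn.Linear(32, 8),
)
galore_targets = ["0", "2"]  # hypothetical target substrings matching the two Linear modules

galore_params = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
        for param in module.parameters():
            if param.requires_grad and len(param.shape) > 1:  # keep 2-D weights only
                galore_params.append(param)

# Remaining trainable parameters (here: the two biases) would go into a regular group.
id_galore_params = {id(param) for param in galore_params}
non_galore_params = [
    param for param in model.parameters()
    if param.requires_grad and id(param) not in id_galore_params
]
print([tuple(p.shape) for p in galore_params])      # [(32, 16), (8, 32)]
print([tuple(p.shape) for p in non_galore_params])  # [(32,), (8,)]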