fix #2775
parent 352693e2dc
commit e874c00906
@@ -1,5 +1,5 @@
 import math
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
 
 import torch
 from transformers.optimization import get_scheduler
@@ -151,16 +151,18 @@ def create_custom_optimzer(
         return None
 
     require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
-    galore_params = []
+    galore_params: List[torch.nn.Parameter] = []
     galore_targets = finetuning_args.galore_target.split(",")
 
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
-            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
+            for param in module.parameters():
+                if param.requires_grad and len(param.shape) > 1:
+                    galore_params.append(param)
 
-    id_galore_params = [id(p) for p in galore_params]
-    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
-    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
+    id_galore_params = {id(param) for param in galore_params}
+    trainable_params = filter(lambda param: param.requires_grad, model.parameters())
+    non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
 
     if training_args.optim == "adamw_torch":
         optim_class = GaLoreAdamW
@@ -168,6 +170,7 @@ def create_custom_optimzer(
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
         }
 
     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
@@ -176,6 +179,7 @@ def create_custom_optimzer(
             "lr": training_args.learning_rate,
             "eps": training_args.adam_epsilon,
             "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "weight_decay": training_args.weight_decay,
             "optim_bits": 8,
             "is_paged": "paged" in training_args.optim,
         }
@@ -184,6 +188,7 @@ def create_custom_optimzer(
         optim_class = GaLoreAdafactor
         optim_kwargs = {
             "lr": training_args.learning_rate,
+            "weight_decay": training_args.weight_decay,
         }
 
     else:
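For context, a minimal runnable sketch of the selection rule this commit introduces: only trainable 2D weights of the targeted Linear modules are collected for GaLore, so biases and frozen parameters are skipped. The toy model and the target name "0" below are hypothetical illustrations; the real code walks the fine-tuned model using the names from finetuning_args.galore_target.

from typing import List

import torch


def collect_galore_params(model: torch.nn.Module, galore_targets: List[str]) -> List[torch.nn.Parameter]:
    # Keep only trainable matrices (ndim > 1) of Linear modules whose name matches a target.
    galore_params: List[torch.nn.Parameter] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    galore_params.append(param)
    return galore_params


# Hypothetical toy model: Sequential names its children "0" and "1".
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
print(len(collect_galore_params(model, galore_targets=["0"])))  # 1 -- the weight matrix; the bias is skipped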
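Likewise, a small sketch (with a hypothetical two-layer model) of the partitioning step this commit touches: id_galore_params becomes a set so that membership is checked in O(1) for each trainable parameter when splitting them into GaLore and non-GaLore groups.

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# Pretend only the first Linear's weight is routed to GaLore (hypothetical choice for illustration).
galore_params = [model[0].weight]

# A set of object ids gives O(1) membership tests when partitioning the trainable parameters.
id_galore_params = {id(param) for param in galore_params}
trainable_params = filter(lambda param: param.requires_grad, model.parameters())
non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]

print(len(non_galore_params))  # 3 -- both biases and the second weight stay in the regular optimizer group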