fix optimizers
This commit is contained in:
parent
a1d31ffc8c
commit
fbbe0dba2f
|
@ -303,11 +303,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
|||
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
|
||||
|
||||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
|
|
|
@ -162,15 +162,6 @@ def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
|
|||
return decay_parameters
|
||||
|
||||
|
||||
def _get_embedding_names(model: "PreTrainedModel") -> List[str]:
|
||||
r"""
|
||||
Returns a list of names of parameters in embedding.
|
||||
"""
|
||||
result = {name for name, _ in model.get_input_embeddings().named_parameters()}
|
||||
result.update(name for name, _ in model.get_output_embeddings().named_parameters())
|
||||
return result
|
||||
|
||||
|
||||
def _create_galore_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
|
@ -225,7 +216,7 @@ def _create_galore_optimizer(
|
|||
|
||||
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
||||
for param in nodecay_params:
|
||||
param_groups = [dict(params=[param])]
|
||||
param_groups = [dict(params=[param], weight_decay=0.0)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
for param in decay_params:
|
||||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
|
||||
|
@ -234,6 +225,14 @@ def _create_galore_optimizer(
|
|||
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
|
||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||
|
||||
def optimizer_hook(param: "torch.nn.Parameter"):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
|
||||
for param in trainable_params:
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
|
||||
else:
|
||||
param_groups = [
|
||||
|
@ -252,11 +251,9 @@ def _create_loraplus_optimizer(
|
|||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("You should use LoRA tuning to activate LoRA+.")
|
||||
|
||||
default_lr = training_args.learning_rate
|
||||
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
|
||||
decay_args = {"weight_decay": training_args.weight_decay}
|
||||
embedding_lr = finetuning_args.loraplus_lr_embedding
|
||||
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
|
||||
|
@ -279,10 +276,10 @@ def _create_loraplus_optimizer(
|
|||
|
||||
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||
param_groups = [
|
||||
dict(params=param_dict["lora_a"], **decay_args),
|
||||
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
|
||||
dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay),
|
||||
dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay),
|
||||
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0),
|
||||
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
|
||||
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
|
||||
]
|
||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||
logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
|
||||
|
@ -294,11 +291,8 @@ def _create_badam_optimizer(
|
|||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
if finetuning_args.badam_mode == "ratio": # filter out the embedding layers for ratio-wise badam
|
||||
decay_param_names = [name for name in decay_param_names if name not in _get_embedding_names(model)]
|
||||
|
||||
decay_params, nodecay_params = [], []
|
||||
decay_param_names = _get_decay_parameter_names(model)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if name in decay_param_names:
|
||||
|
@ -341,6 +335,7 @@ def _create_badam_optimizer(
|
|||
update_ratio=finetuning_args.badam_update_ratio,
|
||||
mask_mode=finetuning_args.badam_mask_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
include_embedding=False,
|
||||
**optim_kwargs,
|
||||
)
|
||||
logger.info(
|
||||
|
@ -379,15 +374,12 @@ def create_custom_scheduler(
|
|||
scheduler_dict[param] = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer_dict[param],
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
|
||||
num_training_steps=num_training_steps * 2,
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
|
||||
def optimizer_hook(param: "torch.nn.Parameter"):
|
||||
if param.grad is not None:
|
||||
optimizer_dict[param].step()
|
||||
optimizer_dict[param].zero_grad()
|
||||
scheduler_dict[param].step()
|
||||
def scheduler_hook(param: "torch.nn.Parameter"):
|
||||
scheduler_dict[param].step()
|
||||
|
||||
for param in optimizer_dict.keys():
|
||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
param.register_post_accumulate_grad_hook(scheduler_hook)
|
||||
|
|
Loading…
Reference in New Issue