[FEATURE]: ADD LORA+ ALGORITHM
This commit is contained in:
parent dfd451b722
commit a0965cd62c
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora_plus/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --preprocessing_num_workers 16 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --warmup_steps 20 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16 \
+    --lora_lr_ratio 16.0
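For orientation: LoRA+ trains the lora_B matrices with a larger learning rate than the lora_A matrices, and --lora_lr_ratio above is that multiplier (16.0 here). A minimal, illustrative sketch of the grouping with a plain PyTorch optimizer, assuming PEFT-style parameter names; this is not the patch's implementation:

import torch

def lora_plus_param_groups(model, lr, lora_lr_ratio, weight_decay=0.0):
    # Illustrative only: lora_B weights go into a group whose lr is scaled
    # by lora_lr_ratio; every other trainable parameter keeps the base lr.
    lora_b = [p for n, p in model.named_parameters() if "lora_B" in n and p.requires_grad]
    rest = [p for n, p in model.named_parameters() if "lora_B" not in n and p.requires_grad]
    return [
        {"params": rest, "lr": lr, "weight_decay": weight_decay},
        {"params": lora_b, "lr": lr * lora_lr_ratio, "weight_decay": weight_decay},
    ]

# optimizer = torch.optim.AdamW(lora_plus_param_groups(peft_model, lr=5e-5, lora_lr_ratio=16.0))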
@@ -210,6 +210,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
+    # LoRA+: Efficient Low Rank Adaptation of Large Models (https://arxiv.org/pdf/2402.12354.pdf)
+    lora_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "LoRA+ learning rate ratio of lora_B to lora_A, e.g. 16.0."},
+    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
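A hedged sketch of how this field surfaces on the command line via transformers' HfArgumentParser; the stand-alone dataclass below is an illustrative stand-in, not the repository's full FinetuningArguments:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

@dataclass
class LoraPlusArguments:  # illustrative stand-in for FinetuningArguments
    lora_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "LoRA+ learning rate ratio of lora_B to lora_A, e.g. 16.0."},
    )

(args,) = HfArgumentParser(LoraPlusArguments).parse_args_into_dataclasses(["--lora_lr_ratio", "16.0"])
assert args.lora_lr_ratio == 16.0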
@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
 from ...train.sft.metric import ComputeMetrics
 from ...train.sft.trainer import CustomSeq2SeqTrainer
 from ...train.utils import create_modelcard_and_push
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_lora_plus_optimizer


 if TYPE_CHECKING:
@@ -51,6 +51,8 @@ def run_sft(

     # Initialize our Trainer
     optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
+    if finetuning_args.lora_lr_ratio:
+        optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
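Since the LoRA+ optimizer is built before the trainer, the per-group learning rates have to survive the learning-rate scheduler. A small, runnable illustration (not the repository's code) that transformers' get_scheduler scales every parameter group by the same factor, so the lora_B / lora_A ratio is preserved through warmup and cosine decay:

import torch
from transformers import get_scheduler

a = torch.nn.Parameter(torch.zeros(2))   # stands in for a lora_A weight
b = torch.nn.Parameter(torch.zeros(2))   # stands in for a lora_B weight
optimizer = torch.optim.AdamW([
    {"params": [a], "lr": 5e-5},
    {"params": [b], "lr": 5e-5 * 16.0},  # LoRA+ group
])
scheduler = get_scheduler("cosine", optimizer, num_warmup_steps=2, num_training_steps=10)
for _ in range(3):
    optimizer.step()
    scheduler.step()
print([group["lr"] for group in optimizer.param_groups])  # still a 16x ratio between groups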
@@ -1,7 +1,8 @@
 import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+from transformers.trainer import Trainer
 import torch
+from torch import nn
 from transformers.optimization import get_scheduler
 from transformers.utils.versions import require_version

@@ -17,7 +18,7 @@ if is_galore_available():

 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
-    from transformers import Seq2SeqTrainingArguments, Trainer
+    from transformers import Seq2SeqTrainingArguments
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead

@@ -244,3 +245,89 @@ def create_custom_optimzer(
     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
+
+
+def optimizer_group_callback(model, lora_lr_ratio, **defaults):
+    """Collect LoRA+ parameters (lora_B / lora_embedding_B) into a group with a scaled learning rate."""
+    params = []
+    names = set()
+    for name, param in model.named_parameters():
+        if "default" in name and ("lora_B" in name or "lora_embedding_B" in name):
+            params.append(param)
+            names.add(name)
+    if params:
+        assert "lr" in defaults
+        # only "lr" is taken from defaults here; the group's weight decay
+        # falls back to the optimizer's default value
+        return names, {
+            "params": params,
+            "lr": defaults["lr"] * lora_lr_ratio,
+        }
+    return None, None
+
+
+def create_lora_plus_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> Optional["torch.optim.Optimizer"]:
+    if finetuning_args.lora_lr_ratio is None:
+        return None
+
+    all_param_names = set()
+    param_groups = []
+    param_names, param_group = optimizer_group_callback(
+        model,
+        lora_lr_ratio=finetuning_args.lora_lr_ratio,
+        lr=training_args.learning_rate,
+        weight_decay=training_args.weight_decay,
+    )
+    if param_names and all_param_names & param_names:
+        raise ValueError("Cannot set one parameter to different param groups")
+    if param_names and param_group:
+        all_param_names.update(param_names)
+        param_groups.append(param_group)
+
+    opt_model = model
+    # same decay/no-decay split as transformers.Trainer, excluding the LoRA+ group above
+    decay_parameters = Trainer.get_decay_parameter_names(None, opt_model)
+    param_groups.extend([
+        {
+            "params": [
+                p for n, p in opt_model.named_parameters()
+                if n in decay_parameters and n not in all_param_names and p.requires_grad
+            ],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [
+                p for n, p in opt_model.named_parameters()
+                if n not in decay_parameters and n not in all_param_names and p.requires_grad
+            ],
+            "weight_decay": 0.0,
+        },
+    ])
+
+    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
+
+    if optimizer_cls.__name__ == "Adam8bit":
+        import bitsandbytes
+
+        # keep embedding layers in 32-bit optimizer states, mirroring Trainer.create_optimizer
+        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+        skipped = 0
+        for module in opt_model.modules():
+            if isinstance(module, nn.Embedding):
+                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                logger.info(f"skipped {module}: {skipped / 2 ** 20}M params")
+                manager.register_module_override(module, "weight", {"optim_bits": 32})
+                logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+        logger.info(f"skipped: {skipped / 2 ** 20}M params")
+
+    return optimizer
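A quick, hedged way to see which parameters the callback selects and what learning rate they get, using a toy module whose parameter names mimic PEFT's "lora_B.default" naming; the import path is an assumption based on the relative imports above:

from torch import nn

from llmtuner.train.utils import optimizer_group_callback  # assumed module path

class ToyLoraLayer(nn.Module):
    def __init__(self, hidden=8, rank=2):
        super().__init__()
        # mimic PEFT adapter naming: lora_A.default / lora_B.default
        self.lora_A = nn.ModuleDict({"default": nn.Linear(hidden, rank, bias=False)})
        self.lora_B = nn.ModuleDict({"default": nn.Linear(rank, hidden, bias=False)})

names, group = optimizer_group_callback(ToyLoraLayer(), lora_lr_ratio=16.0, lr=5e-5, weight_decay=0.0)
print(names)        # {'lora_B.default.weight'}
print(group["lr"])  # 5e-5 * 16.0 = 8e-4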