[FEATURE]: ADD LORA+ ALGORITHM

齐保元 2024-03-13 19:43:27 +08:00
parent dfd451b722
commit a0965cd62c
4 changed files with 130 additions and 3 deletions

View File

@@ -0,0 +1,33 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../data \
--template default \
--finetuning_type lora \
--lora_target q_proj,v_proj \
--output_dir ../../saves/LLaMA2-7B/lora_plus/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--fp16 \
--lora_lr_ratio 16.0
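
For reference, the LoRA+ ratio above is applied on top of --learning_rate: lora_A and every other trainable parameter is updated at 5e-5, while the lora_B matrices are updated 16x faster, matching the default ratio suggested in the LoRA+ paper. A minimal sketch of that arithmetic (variable names are ad-hoc, values copied from the command above):

# Illustration only; these names do not exist in train_bash.py.
base_lr = 5e-5          # --learning_rate
lora_lr_ratio = 16.0    # --lora_lr_ratio
lr_lora_B = base_lr * lora_lr_ratio
print(base_lr, lr_lora_B)  # 5e-05 0.0008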

View File

@@ -210,6 +210,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
        default=False,
        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
    )
    # LoRA+: Efficient Low Rank Adaptation of Large Models (https://arxiv.org/pdf/2402.12354.pdf)
    lora_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "The ratio of the lora_B learning rate to the lora_A learning rate in LoRA+, e.g. 16.0."},
    )
    plot_loss: bool = field(
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_lora_plus_optimizer

if TYPE_CHECKING:
@@ -51,6 +51,8 @@ def run_sft(
    # Initialize our Trainer
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    if finetuning_args.lora_lr_ratio:
        optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)

    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
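
Note: when lora_lr_ratio is set, the LoRA+ optimizer replaces whatever create_custom_optimzer returned (e.g. the GaLore optimizer). A common way to hand a pre-built optimizer to a Hugging Face Trainer subclass is the optimizers tuple; whether CustomSeq2SeqTrainer is wired exactly this way is not visible in this hunk, so the lines below are only a sketch under that assumption:

# Sketch: assumes CustomSeq2SeqTrainer forwards the standard Trainer keyword arguments.
trainer = CustomSeq2SeqTrainer(
    model=model,
    args=training_args,
    optimizers=(optimizer, None),  # None lets the Trainer build its own LR scheduler
    # ... remaining arguments (tokenizer, data_collator, compute_metrics, ...) unchanged ...
)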

View File

@@ -1,7 +1,8 @@
import math
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from transformers.trainer import Trainer

import torch
from torch import nn
from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version

@@ -17,7 +18,7 @@ if is_galore_available():
if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
-    from transformers import Seq2SeqTrainingArguments, Trainer
+    from transformers import Seq2SeqTrainingArguments
    from transformers.modeling_utils import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead
@@ -244,3 +245,89 @@ def create_custom_optimzer(
    logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
    return optimizer


def optimizer_group_callback(model, lora_lr_ratio, **defaults):
    """Collect the LoRA+ parameters (lora_B / lora_embedding_B) into a group with a scaled learning rate."""
    params = []
    names = set()
    for name, param in model.named_parameters():
        # PEFT names adapter weights like "...lora_B.default.weight"
        if "default" in name and ('lora_B' in name or 'lora_embedding_B' in name):
            params.append(param)
            names.add(name)
    if params:
        assert 'lr' in defaults
        return names, {
            'params': params,
            'lr': defaults['lr'] * lora_lr_ratio,
        }
    return None, None


def create_lora_plus_optimizer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if finetuning_args.lora_lr_ratio is None:
        return None

    all_param_names = set()
    param_groups = []
    # LoRA+ group: lora_B matrices are trained with lr = learning_rate * lora_lr_ratio
    param_names, param_group = optimizer_group_callback(
        model, lora_lr_ratio=finetuning_args.lora_lr_ratio,
        lr=training_args.learning_rate,
        weight_decay=training_args.weight_decay)
    if param_names and all_param_names & param_names:
        raise ValueError('Cannot set one parameter to different param groups')
    if param_names and param_group:
        all_param_names.update(param_names)
        param_groups.append(param_group)

    # Remaining trainable parameters follow the usual decay / no-decay split at the base learning rate
    opt_model = model
    decay_parameters = Trainer.get_decay_parameter_names(None, opt_model)
    param_groups.extend([
        {
            'params': [
                p for n, p in opt_model.named_parameters()
                if (n in decay_parameters and n not in all_param_names and p.requires_grad)
            ],
            'weight_decay': training_args.weight_decay,
        },
        {
            'params': [
                p for n, p in opt_model.named_parameters()
                if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
            ],
            'weight_decay': 0.0,
        },
    ])

    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
    optimizer = optimizer_cls(param_groups, **optimizer_kwargs)

    if optimizer_cls.__name__ == 'Adam8bit':
        # Keep embedding weights in 32-bit when using bitsandbytes' 8-bit Adam
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({
                    p.data_ptr(): p.numel()
                    for p in module.parameters()
                }.values())
                logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
                manager.register_module_override(module, 'weight', {'optim_bits': 32})
                logger.debug(f'bitsandbytes: will optimize {module} in fp32')
        logger.info(f'skipped: {skipped / 2 ** 20}M params')

    return optimizer
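
A quick way to sanity-check optimizer_group_callback outside a full training run is to call it on a toy module whose parameter names mimic PEFT's adapter naming ("lora_A.default", "lora_B.default"). The module below is invented for illustration and depends only on torch plus the function defined in this hunk:

from torch import nn

class ToyLoraLayer(nn.Module):
    # Hypothetical stand-in for a PEFT-wrapped layer; "default" is PEFT's usual adapter name.
    def __init__(self, dim=8, rank=2):
        super().__init__()
        self.lora_A = nn.ModuleDict({"default": nn.Linear(dim, rank, bias=False)})
        self.lora_B = nn.ModuleDict({"default": nn.Linear(rank, dim, bias=False)})

toy = ToyLoraLayer()
names, group = optimizer_group_callback(toy, lora_lr_ratio=16.0, lr=5e-5, weight_decay=0.0)
print(names)        # {'lora_B.default.weight'}
print(group['lr'])  # 0.0008, i.e. 5e-5 * 16.0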