improve lora+ impl.

hiyouga 2024-03-13 23:32:51 +08:00
parent 4e5e99af43
commit 72367307df
12 changed files with 165 additions and 169 deletions

View File

@@ -48,7 +48,7 @@ Choose your path:
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

@@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
+
 [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.

 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
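LoRA+ trains the two LoRA factor matrices with different learning rates: lora_A keeps the base learning rate while lora_B gets it scaled by a fixed ratio (exposed here as `loraplus_lr_ratio`). A minimal, hedged sketch of that idea in plain PyTorch; the function and variable names are illustrative, not LLaMA-Factory API:

```python
import torch

def loraplus_param_groups(model: torch.nn.Module, base_lr: float, ratio: float = 16.0):
    """Split trainable LoRA params into two groups: lora_A at base_lr, lora_B at base_lr * ratio."""
    groups = {"lora_a": [], "lora_b": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        groups["lora_b" if "lora_B" in name else "lora_a"].append(param)
    return [
        {"params": groups["lora_a"], "lr": base_lr},
        {"params": groups["lora_b"], "lr": base_lr * ratio},
    ]

# e.g. optimizer = torch.optim.AdamW(loraplus_param_groups(peft_model, base_lr=5e-5), lr=5e-5)
```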

View File

@@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, and more.
 - **Integrated methods**: continual pre-training, supervised fine-tuning, reward model training, PPO training and DPO training.
 - **Multiple precisions**: 32-bit full-parameter fine-tuning, 16-bit freeze fine-tuning, 16-bit LoRA fine-tuning and 2/4/8-bit QLoRA fine-tuning via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent fine-tuning.
+- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent fine-tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, and more.
 - **Fast inference**: OpenAI-style API, web UI and CLI based on vLLM.

@@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## Changelog

+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Use the `loraplus_lr_ratio=16.0` argument to enable the LoRA+ method.
+
 [24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Use the `--use_galore` argument to switch to the memory-efficient optimizer.

 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for blazing-fast concurrent inference. Use `--infer_backend vllm` to get **270%** inference speed. (LoRA is not yet supported, please merge the weights first.)

View File

@@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
     --template default \
     --finetuning_type lora \
    --lora_target q_proj,v_proj \
-    --output_dir ../../saves/LLaMA2-7B/lora_plus/sft \
+    --output_dir ../../saves/LLaMA2-7B/loraplus/sft \
     --overwrite_cache \
     --overwrite_output_dir \
     --cutoff_len 1024 \
@@ -30,4 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
     --val_size 0.1 \
     --plot_loss \
     --fp16 \
-    --lora_lr_ratio 16.0
+    --loraplus_lr_ratio 16.0

View File

@@ -57,7 +57,7 @@ class LoraArguments:
         metadata={
             "help": """Name(s) of target modules to apply LoRA. \
                     Use commas to separate multiple modules. \
-                    Use "all" to specify all the available modules. \
+                    Use "all" to specify all the linear modules. \
                     LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
                     BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
                     Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
@@ -66,6 +66,14 @@ class LoraArguments:
                     Others choices: the same as LLaMA."""
         },
     )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
+    )
+    loraplus_lr_embedding: float = field(
+        default=1e-6,
+        metadata={"help": "LoRA plus learning rate for lora embedding layers."},
+    )
     use_rslora: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
@@ -163,8 +171,11 @@ class GaloreArguments:
         metadata={"help": "Whether or not to use gradient low-Rank projection."},
     )
     galore_target: str = field(
-        default="mlp,attn",
-        metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
+        default="all",
+        metadata={
+            "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
+                    Use "all" to specify all the linear modules."""
+        },
     )
     galore_rank: int = field(
         default=16,
@@ -210,11 +221,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
     )
-    # for lora+,[LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
-    lora_lr_ratio: Optional[float] = field(
-        default=None,
-        metadata={'help': 'The lora learning_rate ratio of lora_A to lora_B, option:16.0.'},
-    )
     plot_loss: bool = field(
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
@@ -230,6 +236,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
+        self.galore_target = split_arg(self.galore_target)
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
         assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
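The post-init above relies on a `split_arg` helper that is outside this diff's context window. A minimal sketch of what such a helper is assumed to do (turn a comma-separated string into a stripped list, pass other values through):

```python
from typing import List, Optional, Union

def split_arg(arg: Optional[Union[str, List[str]]]) -> Optional[List[str]]:
    """Assumed behavior: split a comma-separated string into a list; leave None/lists untouched."""
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg

# With the new default, galore_target="all" becomes ["all"], which the optimizer code
# later expands to every linear module via find_all_linear_modules().
print(split_arg("mlp,attn"))  # ['mlp', 'attn']
print(split_arg("all"))       # ['all']
```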

View File

@@ -1,5 +1,5 @@
 from .loader import load_model, load_model_and_tokenizer, load_tokenizer
-from .utils import load_valuehead_params
+from .utils import find_all_linear_modules, load_valuehead_params

 __all__ = [
@@ -7,4 +7,5 @@ __all__ = [
     "load_model_and_tokenizer",
     "load_tokenizer",
     "load_valuehead_params",
+    "find_all_linear_modules",
 ]

View File

@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
-from .utils import find_all_linear_modules, find_expanded_modules
+from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules

 if TYPE_CHECKING:
@@ -129,9 +129,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

-        if finetuning_args.use_dora:
-            if getattr(model, "quantization_method", None):
-                raise ValueError("DoRA is currently not compatible with quantized models.")
+        if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
+            if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
+                raise ValueError("DoRA is not compatible with PTQ-quantized models.")

         peft_kwargs = {
             "r": finetuning_args.lora_rank,

View File

@@ -109,10 +109,6 @@ def load_model(
     if not is_trainable:
         model.requires_grad_(False)
-        if not getattr(model, "quantization_method", None):
-            for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
-                param.data = param.data.to(model_args.compute_dtype)
-
         model.eval()
     else:
         model.train()

View File

@@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
 from ..extras.packages import is_flash_attn2_available
 from ..extras.patches.llama_patch import apply_llama_patch
 from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
+from .utils import QuantizationMethod

 if TYPE_CHECKING:
@@ -173,10 +174,10 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")

-        if quant_method == "gptq":
+        if quant_method == QuantizationMethod.GPTQ:
             quantization_config["use_exllama"] = False  # disable exllama

-        if quant_method == "aqlm":
+        if quant_method == QuantizationMethod.AQLM:
             require_version(
                 "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
             )
@@ -205,7 +206,7 @@ def _configure_quantization(
     elif model_args.quantization_bit is not None:  # bnb
         if is_deepspeed_zero3_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")

         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")

View File

@@ -1,3 +1,4 @@
+from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List

 import torch
@@ -17,6 +18,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+@unique
+class QuantizationMethod(str, Enum):
+    r"""
+    Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
+    """
+
+    BITS_AND_BYTES = "bitsandbytes"
+    GPTQ = "gptq"
+    AWQ = "awq"
+    AQLM = "aqlm"
+
+
 def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     r"""
     Finds all available modules to apply lora.
@@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     quantization_method = getattr(model, "quantization_method", None)
     if quantization_method is None:
         linear_cls = torch.nn.Linear
-    elif quantization_method == "bitsandbytes":
+    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
         import bitsandbytes as bnb

         linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
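The rest of `find_all_linear_modules` falls outside the hunk's context window. A hedged sketch of what such a helper typically does once `linear_cls` has been chosen; the `lm_head` exclusion and the returned names are assumptions, not part of the shown diff:

```python
from typing import List
import torch

def find_linear_module_names(model: torch.nn.Module, linear_cls=torch.nn.Linear) -> List[str]:
    """Collect the last path component of every linear layer, skipping the LM output head."""
    module_names = set()
    for name, module in model.named_modules():
        if "lm_head" in name:          # assumed exclusion; the output layer is usually left alone
            continue
        if isinstance(module, linear_cls):
            module_names.add(name.split(".")[-1])
    return list(module_names)

# For a LLaMA-style model this would yield names like
# ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"].
```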

View File

@@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
 from ...train.sft.metric import ComputeMetrics
 from ...train.sft.trainer import CustomSeq2SeqTrainer
 from ...train.utils import create_modelcard_and_push
-from ..utils import create_custom_optimzer, create_lora_plus_optimizer
+from ..utils import create_custom_optimzer

 if TYPE_CHECKING:
@@ -51,8 +51,6 @@ def run_sft(
     # Initialize our Trainer
     optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
-    if finetuning_args.lora_lr_ratio:
-        optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,

View File

@@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
 def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
+    model_args.device_map = {"": "cpu"}

     if model_args.export_dir is None:
-        raise ValueError("Please specify `export_dir`.")
+        raise ValueError("Please specify `export_dir` to save model.")

     if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")
@@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

-    if getattr(model, "quantization_method", None):
-        model = model.to("cpu")
-    elif hasattr(model.config, "torch_dtype"):
-        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
-    else:
-        model = model.to(torch.float16).to("cpu")
-        setattr(model.config, "torch_dtype", torch.float16)
+    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
+        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+        model = model.to(output_dtype)
+        setattr(model.config, "torch_dtype", output_dtype)

     model.save_pretrained(
         save_directory=model_args.export_dir,
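The new export path loads the model straight onto CPU (`device_map = {"": "cpu"}`) and only casts the dtype when the model is not quantized. A tiny sketch of the dtype fallback rule; `DummyConfig` and `pick_export_dtype` are illustrative names:

```python
import torch

class DummyConfig:
    """Stand-in for a model config; torch_dtype may or may not be recorded."""

def pick_export_dtype(config) -> torch.dtype:
    # Use the dtype recorded in the config when present, otherwise fall back to float16.
    return getattr(config, "torch_dtype", torch.float16)

cfg = DummyConfig()
assert pick_export_dtype(cfg) is torch.float16   # nothing recorded -> fp16 fallback
cfg.torch_dtype = torch.bfloat16
assert pick_export_dtype(cfg) is torch.bfloat16  # recorded dtype wins
```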

View File

@@ -1,15 +1,17 @@
 import math
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

-from transformers.trainer import Trainer
 import torch
-from torch import nn
+from transformers import Trainer
 from transformers.optimization import get_scheduler
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.trainer_pt_utils import get_parameter_names
 from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
 from ..hparams import FinetuningArguments, ModelArguments
-from ..model import load_model_and_tokenizer, load_valuehead_params
+from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params

 if is_galore_available():
@ -29,9 +31,10 @@ logger = get_logger(__name__)
 class DummyOptimizer(torch.optim.Optimizer):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
         dummy_tensor = torch.randn(1, 1)
-        super().__init__([dummy_tensor], {"lr": 1e-3})
+        self.optimizer_dict = optimizer_dict
+        super().__init__([dummy_tensor], {"lr": lr})

     def zero_grad(self, set_to_none: bool = True) -> None:
         pass
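The widened `DummyOptimizer` signature exists because layer-wise GaLore performs its optimizer steps inside per-parameter gradient hooks (see the hunks below); the Trainer still expects an optimizer object, so this placeholder carries the real learning rate (so logging and scheduling display sensible values) and a handle to the per-parameter optimizers. A minimal re-creation of the pattern, with an extra `step` override added for illustration:

```python
import torch

class PlaceholderOptimizer(torch.optim.Optimizer):
    """Holds one dummy tensor so param_groups / lr exist; real work happens elsewhere."""

    def __init__(self, lr: float = 1e-3, optimizer_dict=None) -> None:
        self.optimizer_dict = optimizer_dict or {}
        dummy_tensor = torch.randn(1, 1)
        super().__init__([dummy_tensor], {"lr": lr})

    def zero_grad(self, set_to_none: bool = True) -> None:
        pass  # gradients are cleared by the per-parameter optimizers

    def step(self, closure=None) -> None:
        pass  # stepping also happens in the gradient hooks, not here

opt = PlaceholderOptimizer(lr=5e-5)
print(opt.param_groups[0]["lr"])  # 5e-05 -- the value the Trainer will report
```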
@ -142,59 +145,33 @@ def create_reward_model(
     return reward_model


-def create_custom_optimzer(
+def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
+    r"""
+    Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
+    """
+    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    return decay_parameters
+
+
+def _create_galore_optimizer(
     model: "PreTrainedModel",
     dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-) -> Optional["torch.optim.Optimizer"]:
-    if not finetuning_args.use_galore:
-        return None
-
+) -> "torch.optim.Optimizer":
     require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
-    galore_params: List[torch.nn.Parameter] = []
-    galore_targets = finetuning_args.galore_target.split(",")
+
+    if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
+        galore_targets = find_all_linear_modules(model)
+    else:
+        galore_targets = finetuning_args.galore_target
+
+    galore_params: List["torch.nn.Parameter"] = []
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
             for param in module.parameters():
                 if param.requires_grad and len(param.shape) > 1:
                     galore_params.append(param)

-    id_galore_params = {id(param) for param in galore_params}
-    trainable_params = filter(lambda param: param.requires_grad, model.parameters())
-    non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
-
-    if training_args.optim == "adamw_torch":
-        optim_class = GaLoreAdamW
-        optim_kwargs = {
-            "lr": training_args.learning_rate,
-            "eps": training_args.adam_epsilon,
-            "betas": (training_args.adam_beta1, training_args.adam_beta2),
-            "weight_decay": training_args.weight_decay,
-        }
-    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
-        optim_class = GaLoreAdamW8bit
-        optim_kwargs = {
-            "lr": training_args.learning_rate,
-            "eps": training_args.adam_epsilon,
-            "betas": (training_args.adam_beta1, training_args.adam_beta2),
-            "weight_decay": training_args.weight_decay,
-            "optim_bits": 8,
-            "is_paged": "paged" in training_args.optim,
-        }
-    elif training_args.optim == "adafactor":
-        optim_class = GaLoreAdafactor
-        optim_kwargs = {
-            "lr": training_args.learning_rate,
-            "weight_decay": training_args.weight_decay,
-        }
-    else:
-        raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
-
     galore_kwargs = {
         "rank": finetuning_args.galore_rank,
         "update_proj_gap": finetuning_args.galore_update_interval,
@ -202,6 +179,30 @@ def create_custom_optimzer(
"proj_type": finetuning_args.galore_proj_type, "proj_type": finetuning_args.galore_proj_type,
} }
id_galore_params = {id(param) for param in galore_params}
decay_params, nodecay_params = [], [] # they are non-galore parameters
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
decay_param_names = _get_decay_parameter_names(model)
for name, param in model.named_parameters():
if param.requires_grad:
trainable_params.append(param)
if id(param) not in id_galore_params:
if name in decay_param_names:
decay_params.append(param)
else:
nodecay_params.append(param)
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
if training_args.optim == "adamw_torch":
optim_class = GaLoreAdamW
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optim_class = GaLoreAdamW8bit
elif training_args.optim == "adafactor":
optim_class = GaLoreAdafactor
else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
if finetuning_args.galore_layerwise: if finetuning_args.galore_layerwise:
if training_args.gradient_accumulation_steps != 1: if training_args.gradient_accumulation_steps != 1:
raise ValueError("Per-layer GaLore does not support gradient accumulation.") raise ValueError("Per-layer GaLore does not support gradient accumulation.")
@ -213,15 +214,18 @@ def create_custom_optimzer(
         num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)

         optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
-        for param in non_galore_params:
+        for param in nodecay_params:
             param_groups = [dict(params=[param])]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in decay_params:
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
         for param in galore_params:
-            param_groups = [dict(params=[param], **galore_kwargs)]
+            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

         scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
-        for param in non_galore_params + galore_params:
+        for param in trainable_params:
             scheduler_dict[param] = get_scheduler(
                 training_args.lr_scheduler_type,
                 optimizer=optimizer_dict[param],
@ -235,99 +239,72 @@ def create_custom_optimzer(
                 optimizer_dict[param].zero_grad()
                 scheduler_dict[param].step()

-        for param in non_galore_params + galore_params:
+        for param in trainable_params:
             param.register_post_accumulate_grad_hook(optimizer_hook)

-        optimizer = DummyOptimizer()
+        optimizer = DummyOptimizer(lr=training_args.learning_rate)  # display scheduler result
     else:
-        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+        param_groups = [
+            dict(params=nodecay_params),
+            dict(params=decay_params, weight_decay=training_args.weight_decay),
+            dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
+        ]
         optimizer = optim_class(param_groups, **optim_kwargs)

     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
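Layer-wise GaLore gives every parameter its own optimizer and LR scheduler and steps them from a hook that fires right after that parameter's gradient is accumulated, so no full-model optimizer state is ever held at once. A minimal, self-contained sketch of that pattern using plain AdamW instead of the GaLore optimizers (requires PyTorch 2.1+ for `register_post_accumulate_grad_hook`):

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer_dict = {}

# One tiny optimizer per parameter; stepping happens inside the gradient hook.
for param in model.parameters():
    if param.requires_grad:
        optimizer_dict[param] = torch.optim.AdamW([param], lr=1e-3)

def optimizer_hook(param: torch.nn.Parameter) -> None:
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    if param.requires_grad:
        param.register_post_accumulate_grad_hook(optimizer_hook)

# A backward pass now also performs the per-parameter updates.
loss = model(torch.randn(2, 8)).sum()
loss.backward()
```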
-def optimizer_group_callback(model, lora_lr_ratio, **defaults):
-    "lora plus"
-    params = []
-    names = set()
-    for name, param in model.named_parameters():
-        if "default" in name and ('lora_B' in name or
-                                  'lora_embedding_B' in name):
-            params.append(param)
-            names.add(name)
-    if params:
-        assert 'lr' in defaults
-        return names, {
-            'params': params,
-            'lr': defaults['lr'] * lora_lr_ratio,
-        }
-    return None, None
-
-
-def create_lora_plus_optimizer(
+def _create_loraplus_optimizer(
     model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-) -> Optional["torch.optim.Optimizer"]:
-    if finetuning_args.lora_lr_ratio is None:
-        return None
-    all_param_names = set()
-    param_groups = []
-    param_names, param_group = optimizer_group_callback(
-        model, lora_lr_ratio=finetuning_args.lora_lr_ratio,
-        lr=training_args.learning_rate,
-        weight_decay=training_args.weight_decay)
-    if param_names and all_param_names & param_names:
-        raise ValueError(
-            'Cannot set one parameter to different param groups')
-    if param_names and param_group:
-        all_param_names.update(param_names)
-        param_groups.append(param_group)
-    opt_model = model
-    decay_parameters = Trainer.get_decay_parameter_names(None, opt_model)
-    param_groups.extend([
-        {
-            'params': [
-                p for n, p in opt_model.named_parameters()
-                if (n in decay_parameters and n not in all_param_names and p.requires_grad)
-            ],
-            'weight_decay': training_args.weight_decay,
-        },
-        {
-            'params': [
-                p for n, p in opt_model.named_parameters()
-                if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
-            ],
-            'weight_decay': 0.0,
-        },
-    ])
-    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
-    optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
-    if optimizer_cls.__name__ == 'Adam8bit':
-        import bitsandbytes
-        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
-        skipped = 0
-        for module in opt_model.modules():
-            if isinstance(module, nn.Embedding):
-                skipped += sum({
-                    p.data_ptr(): p.numel()
-                    for p in module.parameters()
-                }.values())
-                logger.info(
-                    f'skipped {module}: {skipped / 2 ** 20}M params')
-                manager.register_module_override(
-                    module, 'weight', {'optim_bits': 32})
-                logger.debug(
-                    f'bitsandbytes: will optimize {module} in fp32')
-        logger.info(f'skipped: {skipped / 2 ** 20}M params')
+) -> "torch.optim.Optimizer":
+    if finetuning_args.finetuning_type != "lora":
+        raise ValueError("You should use LoRA tuning to activate LoRA+.")
+
+    loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
+    decay_args = {"weight_decay": training_args.weight_decay}
+
+    decay_param_names = _get_decay_parameter_names(model)
+    param_dict: Dict[str, List["torch.nn.Parameter"]] = {
+        "lora_a": [],
+        "lora_b": [],
+        "lora_b_nodecay": [],
+        "embedding": [],
+    }
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            if "lora_embedding_B" in name:
+                param_dict["embedding"].append(param)
+            elif "lora_B" in name or param.ndim == 1:
+                if name in decay_param_names:
+                    param_dict["lora_b"].append(param)
+                else:
+                    param_dict["lora_b_nodecay"].append(param)
+            else:
+                param_dict["lora_a"].append(param)
+
+    optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+    param_groups = [
+        dict(params=param_dict["lora_a"], **decay_args),
+        dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
+        dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
+        dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
+    ]
+    optimizer = optim_class(param_groups, **optim_kwargs)
     return optimizer
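The grouping above relies on how PEFT names LoRA parameters (`lora_A.*.weight`, `lora_B.*.weight`, `lora_embedding_A`/`lora_embedding_B`), plus the `param.ndim == 1` catch-all for trainable 1-D tensors such as biases. A hedged sketch of the classification rule on typical PEFT parameter names (the example names are illustrative):

```python
def classify_loraplus_param(name: str, ndim: int) -> str:
    """Mirror of the grouping rule: embeddings first, then B matrices / 1-D params, else A matrices."""
    if "lora_embedding_B" in name:
        return "embedding"
    if "lora_B" in name or ndim == 1:
        return "lora_b"
    return "lora_a"

assert classify_loraplus_param("model.layers.0.self_attn.q_proj.lora_A.default.weight", 2) == "lora_a"
assert classify_loraplus_param("model.layers.0.self_attn.q_proj.lora_B.default.weight", 2) == "lora_b"
assert classify_loraplus_param("model.embed_tokens.lora_embedding_B.default", 2) == "embedding"
```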
+def create_custom_optimzer(
+    model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> Optional["torch.optim.Optimizer"]:
+    if finetuning_args.use_galore:
+        return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
+
+    if finetuning_args.loraplus_lr_ratio is not None:
+        return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
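Note the precedence: when both `use_galore` and `loraplus_lr_ratio` are set, GaLore wins; when neither is set the function returns `None` and the Trainer builds its default optimizer. A small standalone sketch of that selection logic, with string labels standing in for the real optimizer objects:

```python
from typing import Optional

def select_custom_optimizer(use_galore: bool, loraplus_lr_ratio: Optional[float]) -> Optional[str]:
    """Same branching as create_custom_optimzer, labels instead of optimizers."""
    if use_galore:
        return "galore"
    if loraplus_lr_ratio is not None:
        return "loraplus"
    return None  # Trainer falls back to its default optimizer

assert select_custom_optimizer(True, 16.0) == "galore"      # GaLore takes precedence
assert select_custom_optimizer(False, 16.0) == "loraplus"
assert select_custom_optimizer(False, None) is None
```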