improve lora+ impl.
This commit is contained in:
parent
4e5e99af43
commit
72367307df
|
@ -48,7 +48,7 @@ Choose your path:
|
||||||
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
|
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
|
||||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||||
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoftQ and Agent tuning.
|
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
|
||||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
|
||||||
|
|
||||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
||||||
|
|
||||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
||||||
|
|
|
@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||||
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
- **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
|
||||||
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
|
- **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
|
||||||
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
|
||||||
- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
|
- **先进算法**:GaLore、DoRA、LongLoRA、LLaMA Pro、LoRA+、LoftQ 和 Agent 微调。
|
||||||
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
|
||||||
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。
|
||||||
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
- **极速推理**:基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。
|
||||||
|
@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。
|
||||||
|
|
||||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||||
|
|
||||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||||
|
|
|
@ -9,7 +9,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type lora \
|
--finetuning_type lora \
|
||||||
--lora_target q_proj,v_proj \
|
--lora_target q_proj,v_proj \
|
||||||
--output_dir ../../saves/LLaMA2-7B/lora_plus/sft \
|
--output_dir ../../saves/LLaMA2-7B/loraplus/sft \
|
||||||
--overwrite_cache \
|
--overwrite_cache \
|
||||||
--overwrite_output_dir \
|
--overwrite_output_dir \
|
||||||
--cutoff_len 1024 \
|
--cutoff_len 1024 \
|
||||||
|
@ -30,4 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
|
||||||
--val_size 0.1 \
|
--val_size 0.1 \
|
||||||
--plot_loss \
|
--plot_loss \
|
||||||
--fp16 \
|
--fp16 \
|
||||||
--lora_lr_ratio 16.0
|
--loraplus_lr_ratio 16.0
|
|
@ -57,7 +57,7 @@ class LoraArguments:
|
||||||
metadata={
|
metadata={
|
||||||
"help": """Name(s) of target modules to apply LoRA. \
|
"help": """Name(s) of target modules to apply LoRA. \
|
||||||
Use commas to separate multiple modules. \
|
Use commas to separate multiple modules. \
|
||||||
Use "all" to specify all the available modules. \
|
Use "all" to specify all the linear modules. \
|
||||||
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
|
||||||
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
|
||||||
|
@ -66,6 +66,14 @@ class LoraArguments:
|
||||||
Others choices: the same as LLaMA."""
|
Others choices: the same as LLaMA."""
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
loraplus_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: float = field(
|
||||||
|
default=1e-6,
|
||||||
|
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
|
||||||
|
)
|
||||||
use_rslora: bool = field(
|
use_rslora: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
||||||
|
@ -163,8 +171,11 @@ class GaloreArguments:
|
||||||
metadata={"help": "Whether or not to use gradient low-Rank projection."},
|
metadata={"help": "Whether or not to use gradient low-Rank projection."},
|
||||||
)
|
)
|
||||||
galore_target: str = field(
|
galore_target: str = field(
|
||||||
default="mlp,attn",
|
default="all",
|
||||||
metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
|
metadata={
|
||||||
|
"help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
|
||||||
|
Use "all" to specify all the linear modules."""
|
||||||
|
},
|
||||||
)
|
)
|
||||||
galore_rank: int = field(
|
galore_rank: int = field(
|
||||||
default=16,
|
default=16,
|
||||||
|
@ -210,11 +221,6 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||||
)
|
)
|
||||||
# for lora+,[LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
|
|
||||||
lora_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={'help': 'The lora learning_rate ratio of lora_A to lora_B, option:16.0.'},
|
|
||||||
)
|
|
||||||
plot_loss: bool = field(
|
plot_loss: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to save the training loss curves."},
|
metadata={"help": "Whether or not to save the training loss curves."},
|
||||||
|
@ -230,6 +236,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||||
self.lora_target = split_arg(self.lora_target)
|
self.lora_target = split_arg(self.lora_target)
|
||||||
self.additional_target = split_arg(self.additional_target)
|
self.additional_target = split_arg(self.additional_target)
|
||||||
|
self.galore_target = split_arg(self.galore_target)
|
||||||
|
|
||||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .loader import load_model, load_model_and_tokenizer, load_tokenizer
|
from .loader import load_model, load_model_and_tokenizer, load_tokenizer
|
||||||
from .utils import load_valuehead_params
|
from .utils import find_all_linear_modules, load_valuehead_params
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -7,4 +7,5 @@ __all__ = [
|
||||||
"load_model_and_tokenizer",
|
"load_model_and_tokenizer",
|
||||||
"load_tokenizer",
|
"load_tokenizer",
|
||||||
"load_valuehead_params",
|
"load_valuehead_params",
|
||||||
|
"find_all_linear_modules",
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from .utils import find_all_linear_modules, find_expanded_modules
|
from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -129,9 +129,9 @@ def init_adapter(
|
||||||
if finetuning_args.use_llama_pro:
|
if finetuning_args.use_llama_pro:
|
||||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
|
||||||
|
|
||||||
if finetuning_args.use_dora:
|
if finetuning_args.use_dora and getattr(model, "quantization_method", None) is not None:
|
||||||
if getattr(model, "quantization_method", None):
|
if getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES:
|
||||||
raise ValueError("DoRA is currently not compatible with quantized models.")
|
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
||||||
|
|
||||||
peft_kwargs = {
|
peft_kwargs = {
|
||||||
"r": finetuning_args.lora_rank,
|
"r": finetuning_args.lora_rank,
|
||||||
|
|
|
@ -109,10 +109,6 @@ def load_model(
|
||||||
|
|
||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
if not getattr(model, "quantization_method", None):
|
|
||||||
for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
|
|
||||||
param.data = param.data.to(model_args.compute_dtype)
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
else:
|
else:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
|
@ -18,6 +18,7 @@ from ..extras.misc import get_current_device, infer_optim_dtype
|
||||||
from ..extras.packages import is_flash_attn2_available
|
from ..extras.packages import is_flash_attn2_available
|
||||||
from ..extras.patches.llama_patch import apply_llama_patch
|
from ..extras.patches.llama_patch import apply_llama_patch
|
||||||
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
|
from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl
|
||||||
|
from .utils import QuantizationMethod
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -173,10 +174,10 @@ def _configure_quantization(
|
||||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||||
quant_method = quantization_config.get("quant_method", "")
|
quant_method = quantization_config.get("quant_method", "")
|
||||||
|
|
||||||
if quant_method == "gptq":
|
if quant_method == QuantizationMethod.GPTQ:
|
||||||
quantization_config["use_exllama"] = False # disable exllama
|
quantization_config["use_exllama"] = False # disable exllama
|
||||||
|
|
||||||
if quant_method == "aqlm":
|
if quant_method == QuantizationMethod.AQLM:
|
||||||
require_version(
|
require_version(
|
||||||
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
"transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||||
)
|
)
|
||||||
|
@ -205,7 +206,7 @@ def _configure_quantization(
|
||||||
|
|
||||||
elif model_args.quantization_bit is not None: # bnb
|
elif model_args.quantization_bit is not None: # bnb
|
||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||||
|
|
||||||
if model_args.quantization_bit == 8:
|
if model_args.quantization_bit == 8:
|
||||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from enum import Enum, unique
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -17,6 +18,18 @@ if TYPE_CHECKING:
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@unique
|
||||||
|
class QuantizationMethod(str, Enum):
|
||||||
|
r"""
|
||||||
|
Borrowed from `transformers.utils.quantization_config.QuantizationMethod`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
BITS_AND_BYTES = "bitsandbytes"
|
||||||
|
GPTQ = "gptq"
|
||||||
|
AWQ = "awq"
|
||||||
|
AQLM = "aqlm"
|
||||||
|
|
||||||
|
|
||||||
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||||
r"""
|
r"""
|
||||||
Finds all available modules to apply lora.
|
Finds all available modules to apply lora.
|
||||||
|
@ -24,7 +37,7 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
|
||||||
quantization_method = getattr(model, "quantization_method", None)
|
quantization_method = getattr(model, "quantization_method", None)
|
||||||
if quantization_method is None:
|
if quantization_method is None:
|
||||||
linear_cls = torch.nn.Linear
|
linear_cls = torch.nn.Linear
|
||||||
elif quantization_method == "bitsandbytes":
|
elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
|
||||||
|
|
|
@ -12,7 +12,7 @@ from ...model import load_model, load_tokenizer
|
||||||
from ...train.sft.metric import ComputeMetrics
|
from ...train.sft.metric import ComputeMetrics
|
||||||
from ...train.sft.trainer import CustomSeq2SeqTrainer
|
from ...train.sft.trainer import CustomSeq2SeqTrainer
|
||||||
from ...train.utils import create_modelcard_and_push
|
from ...train.utils import create_modelcard_and_push
|
||||||
from ..utils import create_custom_optimzer, create_lora_plus_optimizer
|
from ..utils import create_custom_optimzer
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -51,8 +51,6 @@ def run_sft(
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
|
optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
|
||||||
if finetuning_args.lora_lr_ratio:
|
|
||||||
optimizer = create_lora_plus_optimizer(model, training_args, finetuning_args)
|
|
||||||
trainer = CustomSeq2SeqTrainer(
|
trainer = CustomSeq2SeqTrainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
|
|
@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
|
||||||
def export_model(args: Optional[Dict[str, Any]] = None):
|
def export_model(args: Optional[Dict[str, Any]] = None):
|
||||||
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
model_args, data_args, finetuning_args, _ = get_infer_args(args)
|
||||||
|
|
||||||
|
model_args.device_map = {"": "cpu"}
|
||||||
|
|
||||||
if model_args.export_dir is None:
|
if model_args.export_dir is None:
|
||||||
raise ValueError("Please specify `export_dir`.")
|
raise ValueError("Please specify `export_dir` to save model.")
|
||||||
|
|
||||||
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
|
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
|
||||||
raise ValueError("Please merge adapters before quantizing the model.")
|
raise ValueError("Please merge adapters before quantizing the model.")
|
||||||
|
@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None):
|
||||||
if not isinstance(model, PreTrainedModel):
|
if not isinstance(model, PreTrainedModel):
|
||||||
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
|
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
|
||||||
|
|
||||||
if getattr(model, "quantization_method", None):
|
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
|
||||||
model = model.to("cpu")
|
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
|
||||||
elif hasattr(model.config, "torch_dtype"):
|
model = model.to(output_dtype)
|
||||||
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
|
setattr(model.config, "torch_dtype", output_dtype)
|
||||||
else:
|
|
||||||
model = model.to(torch.float16).to("cpu")
|
|
||||||
setattr(model.config, "torch_dtype", torch.float16)
|
|
||||||
|
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
save_directory=model_args.export_dir,
|
save_directory=model_args.export_dir,
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
|
||||||
from transformers.trainer import Trainer
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from transformers import Trainer
|
||||||
from transformers.optimization import get_scheduler
|
from transformers.optimization import get_scheduler
|
||||||
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||||
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.packages import is_galore_available
|
from ..extras.packages import is_galore_available
|
||||||
from ..hparams import FinetuningArguments, ModelArguments
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
from ..model import load_model_and_tokenizer, load_valuehead_params
|
from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params
|
||||||
|
|
||||||
|
|
||||||
if is_galore_available():
|
if is_galore_available():
|
||||||
|
@ -29,9 +31,10 @@ logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DummyOptimizer(torch.optim.Optimizer):
|
class DummyOptimizer(torch.optim.Optimizer):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
|
||||||
dummy_tensor = torch.randn(1, 1)
|
dummy_tensor = torch.randn(1, 1)
|
||||||
super().__init__([dummy_tensor], {"lr": 1e-3})
|
self.optimizer_dict = optimizer_dict
|
||||||
|
super().__init__([dummy_tensor], {"lr": lr})
|
||||||
|
|
||||||
def zero_grad(self, set_to_none: bool = True) -> None:
|
def zero_grad(self, set_to_none: bool = True) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -142,59 +145,33 @@ def create_reward_model(
|
||||||
return reward_model
|
return reward_model
|
||||||
|
|
||||||
|
|
||||||
def create_custom_optimzer(
|
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
|
||||||
|
r"""
|
||||||
|
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)
|
||||||
|
"""
|
||||||
|
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
|
||||||
|
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||||
|
return decay_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def _create_galore_optimizer(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
dataset: Union["Dataset", "IterableDataset"],
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
) -> Optional["torch.optim.Optimizer"]:
|
) -> "torch.optim.Optimizer":
|
||||||
if not finetuning_args.use_galore:
|
|
||||||
return None
|
|
||||||
|
|
||||||
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
|
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
|
||||||
galore_params: List[torch.nn.Parameter] = []
|
|
||||||
galore_targets = finetuning_args.galore_target.split(",")
|
|
||||||
|
|
||||||
|
if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
|
||||||
|
galore_targets = find_all_linear_modules(model)
|
||||||
|
|
||||||
|
galore_params: List["torch.nn.Parameter"] = []
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
|
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
|
||||||
for param in module.parameters():
|
for param in module.parameters():
|
||||||
if param.requires_grad and len(param.shape) > 1:
|
if param.requires_grad and len(param.shape) > 1:
|
||||||
galore_params.append(param)
|
galore_params.append(param)
|
||||||
|
|
||||||
id_galore_params = {id(param) for param in galore_params}
|
|
||||||
trainable_params = filter(lambda param: param.requires_grad, model.parameters())
|
|
||||||
non_galore_params = [param for param in trainable_params if id(param) not in id_galore_params]
|
|
||||||
|
|
||||||
if training_args.optim == "adamw_torch":
|
|
||||||
optim_class = GaLoreAdamW
|
|
||||||
optim_kwargs = {
|
|
||||||
"lr": training_args.learning_rate,
|
|
||||||
"eps": training_args.adam_epsilon,
|
|
||||||
"betas": (training_args.adam_beta1, training_args.adam_beta2),
|
|
||||||
"weight_decay": training_args.weight_decay,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
|
|
||||||
optim_class = GaLoreAdamW8bit
|
|
||||||
optim_kwargs = {
|
|
||||||
"lr": training_args.learning_rate,
|
|
||||||
"eps": training_args.adam_epsilon,
|
|
||||||
"betas": (training_args.adam_beta1, training_args.adam_beta2),
|
|
||||||
"weight_decay": training_args.weight_decay,
|
|
||||||
"optim_bits": 8,
|
|
||||||
"is_paged": "paged" in training_args.optim,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif training_args.optim == "adafactor":
|
|
||||||
optim_class = GaLoreAdafactor
|
|
||||||
optim_kwargs = {
|
|
||||||
"lr": training_args.learning_rate,
|
|
||||||
"weight_decay": training_args.weight_decay,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
|
|
||||||
|
|
||||||
galore_kwargs = {
|
galore_kwargs = {
|
||||||
"rank": finetuning_args.galore_rank,
|
"rank": finetuning_args.galore_rank,
|
||||||
"update_proj_gap": finetuning_args.galore_update_interval,
|
"update_proj_gap": finetuning_args.galore_update_interval,
|
||||||
|
@ -202,6 +179,30 @@ def create_custom_optimzer(
|
||||||
"proj_type": finetuning_args.galore_proj_type,
|
"proj_type": finetuning_args.galore_proj_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id_galore_params = {id(param) for param in galore_params}
|
||||||
|
decay_params, nodecay_params = [], [] # they are non-galore parameters
|
||||||
|
trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params
|
||||||
|
decay_param_names = _get_decay_parameter_names(model)
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
trainable_params.append(param)
|
||||||
|
if id(param) not in id_galore_params:
|
||||||
|
if name in decay_param_names:
|
||||||
|
decay_params.append(param)
|
||||||
|
else:
|
||||||
|
nodecay_params.append(param)
|
||||||
|
|
||||||
|
_, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
|
||||||
|
if training_args.optim == "adamw_torch":
|
||||||
|
optim_class = GaLoreAdamW
|
||||||
|
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
|
||||||
|
optim_class = GaLoreAdamW8bit
|
||||||
|
elif training_args.optim == "adafactor":
|
||||||
|
optim_class = GaLoreAdafactor
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
|
||||||
|
|
||||||
if finetuning_args.galore_layerwise:
|
if finetuning_args.galore_layerwise:
|
||||||
if training_args.gradient_accumulation_steps != 1:
|
if training_args.gradient_accumulation_steps != 1:
|
||||||
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
|
raise ValueError("Per-layer GaLore does not support gradient accumulation.")
|
||||||
|
@ -213,15 +214,18 @@ def create_custom_optimzer(
|
||||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||||
|
|
||||||
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
|
||||||
for param in non_galore_params:
|
for param in nodecay_params:
|
||||||
param_groups = [dict(params=[param])]
|
param_groups = [dict(params=[param])]
|
||||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||||
|
for param in decay_params:
|
||||||
|
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
|
||||||
|
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||||
for param in galore_params:
|
for param in galore_params:
|
||||||
param_groups = [dict(params=[param], **galore_kwargs)]
|
param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
|
||||||
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
|
||||||
|
|
||||||
scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
|
scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
|
||||||
for param in non_galore_params + galore_params:
|
for param in trainable_params:
|
||||||
scheduler_dict[param] = get_scheduler(
|
scheduler_dict[param] = get_scheduler(
|
||||||
training_args.lr_scheduler_type,
|
training_args.lr_scheduler_type,
|
||||||
optimizer=optimizer_dict[param],
|
optimizer=optimizer_dict[param],
|
||||||
|
@ -235,99 +239,72 @@ def create_custom_optimzer(
|
||||||
optimizer_dict[param].zero_grad()
|
optimizer_dict[param].zero_grad()
|
||||||
scheduler_dict[param].step()
|
scheduler_dict[param].step()
|
||||||
|
|
||||||
for param in non_galore_params + galore_params:
|
for param in trainable_params:
|
||||||
param.register_post_accumulate_grad_hook(optimizer_hook)
|
param.register_post_accumulate_grad_hook(optimizer_hook)
|
||||||
|
|
||||||
optimizer = DummyOptimizer()
|
optimizer = DummyOptimizer(lr=training_args.learning_rate) # display scheduler result
|
||||||
else:
|
else:
|
||||||
param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
|
param_groups = [
|
||||||
|
dict(params=nodecay_params),
|
||||||
|
dict(params=decay_params, weight_decay=training_args.weight_decay),
|
||||||
|
dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs),
|
||||||
|
]
|
||||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||||
|
|
||||||
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def optimizer_group_callback(model, lora_lr_ratio, **defaults):
|
def _create_loraplus_optimizer(
|
||||||
"lora plus"
|
model: "PreTrainedModel",
|
||||||
params = []
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
names = set()
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
) -> "torch.optim.Optimizer":
|
||||||
|
if finetuning_args.finetuning_type != "lora":
|
||||||
|
raise ValueError("You should use LoRA tuning to activate LoRA+.")
|
||||||
|
|
||||||
|
loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio
|
||||||
|
decay_args = {"weight_decay": training_args.weight_decay}
|
||||||
|
|
||||||
|
decay_param_names = _get_decay_parameter_names(model)
|
||||||
|
param_dict: Dict[str, List["torch.nn.Parameter"]] = {
|
||||||
|
"lora_a": [],
|
||||||
|
"lora_b": [],
|
||||||
|
"lora_b_nodecay": [],
|
||||||
|
"embedding": [],
|
||||||
|
}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if "default" in name and ('lora_B' in name or
|
if param.requires_grad:
|
||||||
'lora_embedding_B' in name):
|
if "lora_embedding_B" in name:
|
||||||
params.append(param)
|
param_dict["embedding"].append(param)
|
||||||
names.add(name)
|
elif "lora_B" in name or param.ndim == 1:
|
||||||
if params:
|
if name in decay_param_names:
|
||||||
assert 'lr' in defaults
|
param_dict["lora_b"].append(param)
|
||||||
return names, {
|
else:
|
||||||
'params': params,
|
param_dict["lora_b_nodecay"].append(param)
|
||||||
'lr': defaults['lr'] * lora_lr_ratio,
|
else:
|
||||||
}
|
param_dict["lora_a"].append(param)
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def create_lora_plus_optimizer(
|
|
||||||
model: "PreTrainedModel",
|
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
finetuning_args: "FinetuningArguments",
|
|
||||||
) -> Optional["torch.optim.Optimizer"]:
|
|
||||||
if finetuning_args.lora_lr_ratio is None:
|
|
||||||
return None
|
|
||||||
all_param_names = set()
|
|
||||||
param_groups = []
|
|
||||||
param_names, param_group = optimizer_group_callback(
|
|
||||||
model, lora_lr_ratio=finetuning_args.lora_lr_ratio,
|
|
||||||
lr=training_args.learning_rate,
|
|
||||||
weight_decay=training_args.weight_decay)
|
|
||||||
if param_names and all_param_names & param_names:
|
|
||||||
raise ValueError(
|
|
||||||
'Cannot set one parameter to different param groups')
|
|
||||||
if param_names and param_group:
|
|
||||||
all_param_names.update(param_names)
|
|
||||||
param_groups.append(param_group)
|
|
||||||
|
|
||||||
opt_model = model
|
|
||||||
decay_parameters = Trainer.get_decay_parameter_names(None, opt_model)
|
|
||||||
param_groups.extend([
|
|
||||||
{
|
|
||||||
'params': [
|
|
||||||
p for n, p in opt_model.named_parameters()
|
|
||||||
if (n in decay_parameters and n not in all_param_names and p.requires_grad)
|
|
||||||
],
|
|
||||||
'weight_decay':
|
|
||||||
training_args.weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'params': [
|
|
||||||
p for n, p in opt_model.named_parameters()
|
|
||||||
if (n not in decay_parameters and n not in all_param_names and p.requires_grad)
|
|
||||||
],
|
|
||||||
'weight_decay':
|
|
||||||
0.0,
|
|
||||||
},
|
|
||||||
])
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
|
||||||
|
|
||||||
optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
|
|
||||||
|
|
||||||
if optimizer_cls.__name__ == 'Adam8bit':
|
|
||||||
import bitsandbytes
|
|
||||||
|
|
||||||
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
|
||||||
|
|
||||||
skipped = 0
|
|
||||||
for module in opt_model.modules():
|
|
||||||
if isinstance(module, nn.Embedding):
|
|
||||||
skipped += sum({
|
|
||||||
p.data_ptr(): p.numel()
|
|
||||||
for p in module.parameters()
|
|
||||||
}.values())
|
|
||||||
logger.info(
|
|
||||||
f'skipped {module}: {skipped / 2 ** 20}M params')
|
|
||||||
manager.register_module_override(
|
|
||||||
module, 'weight', {'optim_bits': 32})
|
|
||||||
logger.debug(
|
|
||||||
f'bitsandbytes: will optimize {module} in fp32')
|
|
||||||
logger.info(f'skipped: {skipped / 2 ** 20}M params')
|
|
||||||
|
|
||||||
|
optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
|
||||||
|
param_groups = [
|
||||||
|
dict(params=param_dict["lora_a"], **decay_args),
|
||||||
|
dict(params=param_dict["lora_b"], lr=loraplus_lr, **decay_args),
|
||||||
|
dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr),
|
||||||
|
dict(params=param_dict["embedding"], lr=finetuning_args.loraplus_lr_embedding, **decay_args),
|
||||||
|
]
|
||||||
|
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def create_custom_optimzer(
|
||||||
|
model: "PreTrainedModel",
|
||||||
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
finetuning_args: "FinetuningArguments",
|
||||||
|
) -> Optional["torch.optim.Optimizer"]:
|
||||||
|
if not finetuning_args.use_galore:
|
||||||
|
return _create_galore_optimizer(model, dataset, training_args, finetuning_args)
|
||||||
|
|
||||||
|
if finetuning_args.loraplus_lr_ratio is not None:
|
||||||
|
return _create_loraplus_optimizer(model, dataset, training_args, finetuning_args)
|
||||||
|
|
Loading…
Reference in New Issue