fix galore

hiyouga 2024-03-08 00:44:51 +08:00
parent 57452a4aa1
commit 33a4c24a8a
11 changed files with 129 additions and 25 deletions

View File

@@ -70,7 +70,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/03/07] We supported the **[GaLore](https://arxiv.org/abs/2403.03507)** algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster, concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
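The same flags can also be passed programmatically instead of through `train_bash.py`. The sketch below is illustrative only: it assumes the package is installed as `llmtuner` and that `run_exp` accepts a dict of CLI-style arguments, mirroring the `get_train_args(args: Optional[Dict[str, Any]] = None)` signature touched later in this commit; the hyperparameter values are copied from the example script added below.

```python
# Illustrative sketch, not part of this commit. Assumes `llmtuner` is installed and that
# `run_exp` forwards a dict of arguments to the HfArgumentParser-based option parser.
from llmtuner import run_exp

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",
    dataset="alpaca_gpt4_en,glaive_toolcall",
    template="default",
    finetuning_type="full",
    use_galore=True,                 # the new memory-efficient optimizer switch
    galore_target="mlp,self_attn",
    galore_rank=32,
    optim="adamw_8bit",
    output_dir="saves/LLaMA2-7B/galore/sft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    pure_bf16=True,
))
```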

View File

@@ -70,7 +70,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## Changelog
[24/03/07] We supported the **[GaLore](https://arxiv.org/abs/2403.03507)** algorithm. Use the `--use_galore` argument to switch to the memory-efficient optimizer.
[24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Use the `--use_galore` argument to switch to the memory-efficient optimizer.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for blazing-fast concurrent inference. Use `--infer_backend vllm` to get **270%** inference speed. (LoRA is not yet supported; please merge the weights first.)

View File

@@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type freeze \
--name_module_trainable mlp,self_attn \
--num_layer_trainable 8 \
--finetuning_type full \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \

View File

@@ -0,0 +1,32 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type freeze \
--name_module_trainable mlp,self_attn \
--num_layer_trainable 8 \
--finetuning_type full \
--use_galore \
--galore_target mlp,self_attn \
--galore_rank 32 \

View File

@@ -0,0 +1,35 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_galore \
--galore_target mlp,self_attn \
--galore_rank 32 \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@@ -18,6 +18,19 @@ def get_requires():
return lines
extra_require = {
"deepspeed": ["deepspeed==0.13.1"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"],
"vllm": ["vllm==0.3.3"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu,cpu]"],
"galore": ["galore_torch @ git+https://github.com/jiaweizzhao/GaLore.git"],
}
def main():
setup(
@@ -35,6 +48,7 @@ def main():
packages=find_packages("src"),
python_requires=">=3.8.0",
install_requires=get_requires(),
extras_require=extra_require,
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",

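The new `galore` extra added to setup.py above lets the optional dependency be pulled in at install time, for example with `pip install -e ".[galore]"` from the repository root (an assumed, typical editable-install workflow). A minimal, self-contained way to check that the extra resolved before training:

```python
# Illustrative sanity check, not part of this commit: confirm the optional GaLore
# dependency declared in the "galore" extra is importable.
import importlib.util

if importlib.util.find_spec("galore_torch") is None:
    raise ImportError('GaLore is missing; install the extra first, e.g. pip install -e ".[galore]"')

from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor  # optimizers used by the trainer utils below
```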
View File

@@ -66,10 +66,6 @@ class LoraArguments:
Other choices: the same as LLaMA."""
},
)
lora_bf16_mode: bool = field(
default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
)
use_rslora: bool = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
@@ -194,6 +190,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."},
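Because `pure_bf16` is a plain dataclass field with a `False` default, a bare `--pure_bf16` on the command line is enough to enable it. A small self-contained illustration of the mechanism (the `DemoArgs` class is hypothetical; only the `HfArgumentParser` behaviour is the point):

```python
# Hypothetical mini-example: a boolean dataclass field with default=False, like pure_bf16
# above, becomes an optional flag under transformers' HfArgumentParser.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DemoArgs:
    pure_bf16: bool = field(
        default=False,
        metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
    )

(demo,) = HfArgumentParser(DemoArgs).parse_args_into_dataclasses(["--pure_bf16"])
assert demo.pure_bf16 is True  # a bare flag flips the default
```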

View File

@@ -7,6 +7,7 @@ import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from ..extras.logging import get_logger
from ..extras.misc import check_dependencies
@@ -156,6 +157,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.")
if training_args.fp16 or training_args.bf16:
raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
_verify_model_args(model_args, finetuning_args)
if (
@@ -226,9 +234,11 @@
)
# Post-process model arguments
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
if training_args.bf16 or finetuning_args.pure_bf16:
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.cutoff_len
model_args.aqlm_optimization = not training_args.predict_with_generate
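Read together, the two hunks above give `pure_bf16` a simple contract: it requires bf16-capable hardware, is mutually exclusive with the AMP flags, and resolves the compute dtype to bfloat16. A standalone restatement (illustrative only; the function name is made up):

```python
# Illustrative restatement of the pure_bf16 rules added in this file; not repository code.
from typing import Optional

import torch
from transformers.utils import is_torch_bf16_gpu_available

def resolve_compute_dtype(fp16: bool, bf16: bool, pure_bf16: bool) -> Optional[torch.dtype]:
    if pure_bf16:
        if not is_torch_bf16_gpu_available():
            raise ValueError("This device does not support `pure_bf16`.")
        if fp16 or bf16:
            raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
    if bf16 or pure_bf16:
        return torch.bfloat16
    if fp16:
        return torch.float16
    return None  # keep the checkpoint's default dtype
```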

View File

@@ -34,6 +34,7 @@ def init_adapter(
if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full")
if not finetuning_args.pure_bf16:
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable:
@@ -78,6 +79,7 @@ def init_adapter(
for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers):
if not finetuning_args.pure_bf16:
param.data = param.data.to(torch.float32)
else:
param.requires_grad_(False)
@@ -150,8 +152,9 @@ def init_adapter(
)
model = get_peft_model(model, lora_config)
if not finetuning_args.pure_bf16:
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)
param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
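In all three fine-tuning branches above, `pure_bf16` simply skips the upcast of trainable parameters to fp32. A hypothetical helper (not from this repository) to confirm what that leaves behind on a prepared model:

```python
# Hypothetical inspection helper: count the dtypes of trainable parameters after init_adapter.
# Expected: only torch.float32 in the default path, only torch.bfloat16 when --pure_bf16 is set.
from collections import Counter

import torch

def trainable_dtype_counts(model: torch.nn.Module) -> Counter:
    return Counter(str(p.dtype) for p in model.parameters() if p.requires_grad)
```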

View File

@@ -154,14 +154,28 @@ def create_custom_optimzer(
},
]
if training_args.optim == "adamw_torch":
optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
elif training_args.optim == "adamw_8bit":
optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
optimizer = GaLoreAdamW(
param_groups,
lr=training_args.learning_rate,
eps=training_args.adam_epsilon,
betas=(training_args.adam_beta1, training_args.adam_beta2),
)
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optimizer = GaLoreAdamW8bit(
param_groups,
lr=training_args.learning_rate,
eps=training_args.adam_epsilon,
betas=(training_args.adam_beta1, training_args.adam_beta2),
optim_bits=8,
is_paged="paged" in training_args.optim,
)
elif training_args.optim == "adafactor":
optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
optimizer = GaLoreAdafactor(
param_groups,
lr=training_args.learning_rate,
)
else:
raise NotImplementedError("Unknown optim: {}".format(training_args.optim))
logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
logger.info("Using GaLore optimizer; this may cause hanging at the start of training, please wait patiently.")
return optimizer
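The `param_groups` consumed above follow the upstream GaLore convention: ordinary groups plus groups that carry the extra `rank`, `update_proj_gap`, `scale`, and `proj_type` keys. A hedged sketch of how such groups can be assembled (the helper name and the gap/scale values come from the GaLore README example, not from this repository):

```python
# Illustrative sketch of GaLore-style parameter groups; not the code used by this commit.
import torch
from torch import nn
from galore_torch import GaLoreAdamW8bit  # available via the new "galore" extra

def build_galore_param_groups(model: nn.Module, rank: int = 32, targets=("mlp", "self_attn")):
    """2D weights inside the target modules get the GaLore keys; everything else is a plain group."""
    galore, regular = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.dim() == 2 and any(t in name for t in targets):
            galore.append(param)
        else:
            regular.append(param)
    return [
        {"params": regular},
        {"params": galore, "rank": rank, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
    ]

# e.g. optimizer = GaLoreAdamW8bit(build_galore_param_groups(model), lr=5e-5)
```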