fix galore

This commit is contained in:
hiyouga 2024-03-08 00:44:51 +08:00
parent 57452a4aa1
commit 33a4c24a8a
11 changed files with 129 additions and 25 deletions

View File

@ -70,7 +70,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/03/07] We supported **[GaLore](https://arxiv.org/abs/2403.03507)** algorithm. Try `--use_galore` to use the memory-efficient optimizer. [24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.) [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)

View File

@ -70,7 +70,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## 更新日志 ## 更新日志
[24/03/07] 我们支持了 **[GaLore](https://arxiv.org/abs/2403.03507)** 算法。请使用 `--use_galore` 参数切换显存高效的优化器。 [24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**算法。请使用 `--use_galore` 参数切换显存高效的优化器。
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。 [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA请先合并权重。

View File

@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset alpaca_gpt4_en,glaive_toolcall \ --dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \ --dataset_dir ../../../data \
--template default \ --template default \
--finetuning_type freeze \ --finetuning_type full \
--name_module_trainable mlp,self_attn \
--num_layer_trainable 8 \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \ --output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \ --overwrite_cache \
--overwrite_output_dir \ --overwrite_output_dir \

View File

@ -0,0 +1,32 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@ -7,9 +7,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--dataset alpaca_gpt4_en,glaive_toolcall \ --dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \ --dataset_dir ../../../data \
--template default \ --template default \
--finetuning_type freeze \ --finetuning_type full \
--name_module_trainable mlp,self_attn \
--num_layer_trainable 8 \
--use_galore \ --use_galore \
--galore_target mlp,self_attn \ --galore_target mlp,self_attn \
--galore_rank 32 \ --galore_rank 32 \

View File

@ -0,0 +1,35 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
--stage sft \
--do_train \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--dataset alpaca_gpt4_en,glaive_toolcall \
--dataset_dir ../../../data \
--template default \
--finetuning_type full \
--use_galore \
--galore_target mlp,self_attn \
--galore_rank 32 \
--optim adamw_8bit \
--output_dir ../../../saves/LLaMA2-7B/galore/sft \
--overwrite_cache \
--overwrite_output_dir \
--cutoff_len 1024 \
--preprocessing_num_workers 16 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--warmup_steps 20 \
--save_steps 100 \
--eval_steps 100 \
--evaluation_strategy steps \
--load_best_model_at_end \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--max_samples 3000 \
--val_size 0.1 \
--plot_loss \
--pure_bf16

View File

@ -18,6 +18,19 @@ def get_requires():
return lines return lines
extra_require = {
"deepspeed": ["deepspeed==0.13.1"],
"metrics": ["nltk", "jieba", "rouge-chinese"],
"unsloth": ["unsloth[cu121-ampere-torch220] @ git+https://github.com/unslothai/unsloth.git"],
"vllm": ["vllm==0.3.3"],
"bitsandbytes": ["bitsandbytes>=0.39.0"],
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
"awq": ["autoawq"],
"aqlm": ["aqlm[gpu,cpu]"],
"galore": ["galore_torch @ git+https://github.com/jiaweizzhao/GaLore.git"],
}
def main(): def main():
setup( setup(
@ -35,6 +48,7 @@ def main():
packages=find_packages("src"), packages=find_packages("src"),
python_requires=">=3.8.0", python_requires=">=3.8.0",
install_requires=get_requires(), install_requires=get_requires(),
extras_require=extra_require,
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",

View File

@ -66,10 +66,6 @@ class LoraArguments:
Others choices: the same as LLaMA.""" Others choices: the same as LLaMA."""
}, },
) )
lora_bf16_mode: bool = field(
default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
)
use_rslora: bool = field( use_rslora: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
@ -194,6 +190,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
pure_bf16: bool = field(
default=False,
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
)
stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field( stage: Literal["pt", "sft", "rm", "ppo", "dpo"] = field(
default="sft", default="sft",
metadata={"help": "Which stage will be performed in training."}, metadata={"help": "Which stage will be performed in training."},

View File

@ -7,6 +7,7 @@ import torch
import transformers import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_torch_bf16_gpu_available
from ..extras.logging import get_logger from ..extras.logging import get_logger
from ..extras.misc import check_dependencies from ..extras.misc import check_dependencies
@ -156,6 +157,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if model_args.use_unsloth: if model_args.use_unsloth:
raise ValueError("Unsloth does not support DoRA.") raise ValueError("Unsloth does not support DoRA.")
if finetuning_args.pure_bf16:
if not is_torch_bf16_gpu_available():
raise ValueError("This device does not support `pure_bf16`.")
if training_args.fp16 or training_args.bf16:
raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
_verify_model_args(model_args, finetuning_args) _verify_model_args(model_args, finetuning_args)
if ( if (
@ -226,9 +234,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
) )
# Post-process model arguments # Post-process model arguments
model_args.compute_dtype = ( if training_args.bf16 or finetuning_args.pure_bf16:
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None) model_args.compute_dtype = torch.bfloat16
) elif training_args.fp16:
model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.cutoff_len model_args.model_max_length = data_args.cutoff_len
model_args.aqlm_optimization = not training_args.predict_with_generate model_args.aqlm_optimization = not training_args.predict_with_generate

View File

@ -34,7 +34,8 @@ def init_adapter(
if finetuning_args.finetuning_type == "full" and is_trainable: if finetuning_args.finetuning_type == "full" and is_trainable:
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
model = model.float() if not finetuning_args.pure_bf16:
model = model.float()
if finetuning_args.finetuning_type == "freeze" and is_trainable: if finetuning_args.finetuning_type == "freeze" and is_trainable:
logger.info("Fine-tuning method: Freeze") logger.info("Fine-tuning method: Freeze")
@ -78,7 +79,8 @@ def init_adapter(
for name, param in model.named_parameters(): for name, param in model.named_parameters():
if any(trainable_layer in name for trainable_layer in trainable_layers): if any(trainable_layer in name for trainable_layer in trainable_layers):
param.data = param.data.to(torch.float32) if not finetuning_args.pure_bf16:
param.data = param.data.to(torch.float32)
else: else:
param.requires_grad_(False) param.requires_grad_(False)
@ -150,8 +152,9 @@ def init_adapter(
) )
model = get_peft_model(model, lora_config) model = get_peft_model(model, lora_config)
for param in filter(lambda p: p.requires_grad, model.parameters()): if not finetuning_args.pure_bf16:
param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32) for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)
if model_args.adapter_name_or_path is not None: if model_args.adapter_name_or_path is not None:
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

View File

@ -154,14 +154,28 @@ def create_custom_optimzer(
}, },
] ]
if training_args.optim == "adamw_torch": if training_args.optim == "adamw_torch":
optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate) optimizer = GaLoreAdamW(
elif training_args.optim == "adamw_8bit": param_groups,
optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate) lr=training_args.learning_rate,
eps=training_args.adam_epsilon,
betas=(training_args.adam_beta1, training_args.adam_beta2),
)
elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
optimizer = GaLoreAdamW8bit(
param_groups,
lr=training_args.learning_rate,
eps=training_args.adam_epsilon,
betas=(training_args.adam_beta1, training_args.adam_beta2),
optim_bits=8,
is_paged="paged" in training_args.optim,
)
elif training_args.optim == "adafactor": elif training_args.optim == "adafactor":
optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate) optimizer = GaLoreAdafactor(
param_groups,
lr=training_args.learning_rate,
)
else: else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.") logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer return optimizer