support fsdp + qlora
This commit is contained in:
parent
3271af2afc
commit
8408225162
|
@ -70,17 +70,19 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
||||||
|
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
[24/03/20] We supported **FSDP + QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
|
||||||
|
|
||||||
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
|
[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
|
||||||
|
|
||||||
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
||||||
|
|
||||||
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
[24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
||||||
|
|
||||||
|
<details><summary>Full Changelog</summary>
|
||||||
|
|
||||||
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
|
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
|
||||||
|
|
||||||
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage.
|
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `examples/extras/llama_pro` for usage.
|
||||||
|
|
||||||
<details><summary>Full Changelog</summary>
|
|
||||||
|
|
||||||
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
|
||||||
|
|
||||||
|
@ -238,6 +240,7 @@ You also can add a custom chat template to [template.py](src/llmtuner/data/templ
|
||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
|
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||||
|
|
||||||
|
|
|
@ -70,17 +70,19 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||||
|
|
||||||
## 更新日志
|
## 更新日志
|
||||||
|
|
||||||
|
[24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP + QLoRA**。详细用法请参照 `examples/fsdp_qlora`。
|
||||||
|
|
||||||
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。
|
[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。
|
||||||
|
|
||||||
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||||
|
|
||||||
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
[24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||||
|
|
||||||
|
<details><summary>展开日志</summary>
|
||||||
|
|
||||||
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||||
|
|
||||||
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py`。
|
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `examples/extras/llama_pro`。
|
||||||
|
|
||||||
<details><summary>展开日志</summary>
|
|
||||||
|
|
||||||
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
[24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
|
||||||
|
|
||||||
|
@ -238,6 +240,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
||||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||||
|
- [Orca DPO (en)](https://huggingface.co/datasets/Intel/orca_dpo_pairs)
|
||||||
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
|
||||||
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
debug: false
|
||||||
|
distributed_type: FSDP
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_backward_prefetch: BACKWARD_PRE
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_forward_prefetch: false
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: fp16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
tpu_env: []
|
||||||
|
tpu_use_cluster: false
|
||||||
|
tpu_use_sudo: false
|
||||||
|
use_cpu: false
|
|
@ -0,0 +1,5 @@
|
||||||
|
```bash
|
||||||
|
pip install git+https://github.com/huggingface/transformers.git
|
||||||
|
pip install "accelerate>=0.28.0"
|
||||||
|
pip install "bitsandbytes>=0.43.0"
|
||||||
|
```
|
|
@ -0,0 +1,33 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
|
||||||
|
--config_file ../accelerate/fsdp_config.yaml \
|
||||||
|
../../src/train_bash.py \
|
||||||
|
--stage sft \
|
||||||
|
--do_train \
|
||||||
|
--model_name_or_path meta-llama/Llama-2-70b-hf \
|
||||||
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
|
--dataset_dir ../../data \
|
||||||
|
--template default \
|
||||||
|
--finetuning_type lora \
|
||||||
|
--lora_target q_proj,v_proj \
|
||||||
|
--output_dir ../../saves/LLaMA2-70B/lora/sft \
|
||||||
|
--overwrite_cache \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--cutoff_len 1024 \
|
||||||
|
--per_device_train_batch_size 1 \
|
||||||
|
--per_device_eval_batch_size 1 \
|
||||||
|
--gradient_accumulation_steps 8 \
|
||||||
|
--lr_scheduler_type cosine \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--save_steps 100 \
|
||||||
|
--eval_steps 100 \
|
||||||
|
--evaluation_strategy steps \
|
||||||
|
--load_best_model_at_end \
|
||||||
|
--learning_rate 5e-5 \
|
||||||
|
--num_train_epochs 3.0 \
|
||||||
|
--max_samples 3000 \
|
||||||
|
--val_size 0.1 \
|
||||||
|
--quantization_bit 4 \
|
||||||
|
--plot_loss \
|
||||||
|
--fp16
|
|
@ -7,11 +7,11 @@ python -m torch.distributed.run \
|
||||||
--master_addr $MASTER_ADDR \
|
--master_addr $MASTER_ADDR \
|
||||||
--master_port $MASTER_PORT \
|
--master_port $MASTER_PORT \
|
||||||
../../src/train_bash.py \
|
../../src/train_bash.py \
|
||||||
--deepspeed ds_z3_config.json \
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--do_train \
|
--do_train \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
--dataset_dir ../../data \
|
--dataset_dir ../../data \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type full \
|
--finetuning_type full \
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
deepspeed --num_gpus 4 ../../src/train_bash.py \
|
||||||
--deepspeed ds_z3_config.json \
|
--deepspeed ../deepspeed/ds_z3_config.json \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--do_train \
|
--do_train \
|
||||||
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
--model_name_or_path meta-llama/Llama-2-7b-hf \
|
||||||
--dataset alpaca_gpt4_en \
|
--dataset alpaca_gpt4_en,glaive_toolcall \
|
||||||
--dataset_dir ../../data \
|
--dataset_dir ../../data \
|
||||||
--template default \
|
--template default \
|
||||||
--finetuning_type full \
|
--finetuning_type full \
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch \
|
||||||
--config_file master_config.yaml \
|
--config_file ../accelerate/master_config.yaml \
|
||||||
../../src/train_bash.py \
|
../../src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
|
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
|
||||||
--config_file single_config.yaml \
|
--config_file ../accelerate/single_config.yaml \
|
||||||
../../src/train_bash.py \
|
../../src/train_bash.py \
|
||||||
--stage sft \
|
--stage sft \
|
||||||
--do_train \
|
--do_train \
|
||||||
|
|
|
@ -3,7 +3,7 @@ transformers>=4.37.2
|
||||||
datasets>=2.14.3
|
datasets>=2.14.3
|
||||||
accelerate>=0.27.2
|
accelerate>=0.27.2
|
||||||
peft>=0.9.0
|
peft>=0.9.0
|
||||||
trl>=0.7.11
|
trl>=0.8.1
|
||||||
gradio>=3.38.0,<4.0.0
|
gradio>=3.38.0,<4.0.0
|
||||||
scipy
|
scipy
|
||||||
einops
|
einops
|
||||||
|
|
|
@ -65,7 +65,7 @@ def check_dependencies() -> None:
|
||||||
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
|
||||||
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
|
||||||
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
|
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
|
||||||
require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
|
require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")
|
||||||
|
|
||||||
|
|
||||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
@ -81,7 +81,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||||
|
|
||||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||||
if param.__class__.__name__ == "Params4bit":
|
if param.__class__.__name__ == "Params4bit":
|
||||||
num_params = num_params * 2
|
num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
|
||||||
|
num_params = num_params * 2 * num_bytes
|
||||||
|
|
||||||
all_param += num_params
|
all_param += num_params
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
|
|
|
@ -210,9 +210,6 @@ def _configure_quantization(
|
||||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||||
|
|
||||||
elif model_args.quantization_bit is not None: # bnb
|
elif model_args.quantization_bit is not None: # bnb
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
|
||||||
|
|
||||||
if model_args.quantization_bit == 8:
|
if model_args.quantization_bit == 8:
|
||||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
@ -224,6 +221,7 @@ def _configure_quantization(
|
||||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||||
bnb_4bit_quant_type=model_args.quantization_type,
|
bnb_4bit_quant_type=model_args.quantization_type,
|
||||||
|
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp qlora
|
||||||
)
|
)
|
||||||
|
|
||||||
init_kwargs["device_map"] = {"": get_current_device()}
|
init_kwargs["device_map"] = {"": get_current_device()}
|
||||||
|
@ -300,7 +298,7 @@ def patch_config(
|
||||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||||
if not is_deepspeed_zero3_enabled():
|
if not is_deepspeed_zero3_enabled():
|
||||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
||||||
if model_args.low_cpu_mem_usage:
|
if init_kwargs["low_cpu_mem_usage"]:
|
||||||
if "device_map" not in init_kwargs: # quant models cannot use auto device map
|
if "device_map" not in init_kwargs: # quant models cannot use auto device map
|
||||||
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
|
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue