diff --git a/README.md b/README.md
index 605e6ad9..b9a7f7dc 100644
--- a/README.md
+++ b/README.md
@@ -72,9 +72,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [24/03/20] We supported **FSDP+QLoRA** that fine-tunes a 70B model on 2x24GB GPUs. See `examples/fsdp_qlora` for usage.
 
-[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. Try `loraplus_lr_ratio=16.0` to enable LoRA+ algorithm.
+[24/03/13] We supported **[LoRA+](https://arxiv.org/abs/2402.12354)**. See `examples/extras/loraplus` for usage.
 
-[24/03/07] We supported gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
+[24/03/07] We supported the gradient low-rank projection (**[GaLore](https://arxiv.org/abs/2403.03507)**) algorithm. See `examples/extras/galore` for usage.
 
 [24/03/07] We integrated **[vLLM](https://github.com/vllm-project/vllm)** for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
 
diff --git a/README_zh.md b/README_zh.md
index 242c1ff7..beb88e7d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -72,9 +72,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 
 [24/03/20] 我们支持了能在 2x24GB GPU 上微调 70B 模型的 **FSDP+QLoRA**。详细用法请参照 `examples/fsdp_qlora`。
 
-[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。请使用 `loraplus_lr_ratio=16.0` 参数开启 LoRA+ 方法。
+[24/03/13] 我们支持了 **[LoRA+](https://arxiv.org/abs/2402.12354)**。详细用法请参照 `examples/extras/loraplus`。
 
-[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。请使用 `--use_galore` 参数切换显存高效的优化器。
+[24/03/07] 我们支持了梯度低秩投影(**[GaLore](https://arxiv.org/abs/2403.03507)**)算法。详细用法请参照 `examples/extras/galore`。
 
 [24/03/07] 我们集成了 **[vLLM](https://github.com/vllm-project/vllm)** 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
 
diff --git a/examples/extras/galore/adamw.sh b/examples/extras/galore/adamw.sh
deleted file mode 100644
index d4f5afb4..00000000
--- a/examples/extras/galore/adamw.sh
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/bin/bash
-
-CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
-    --stage sft \
-    --do_train \
-    --model_name_or_path meta-llama/Llama-2-7b-hf \
-    --dataset alpaca_gpt4_en,glaive_toolcall \
-    --dataset_dir ../../../data \
-    --template default \
-    --finetuning_type full \
-    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
-    --overwrite_cache \
-    --overwrite_output_dir \
-    --cutoff_len 1024 \
-    --preprocessing_num_workers 16 \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 1 \
-    --lr_scheduler_type cosine \
-    --logging_steps 10 \
-    --warmup_steps 20 \
-    --save_steps 100 \
-    --eval_steps 100 \
-    --evaluation_strategy steps \
-    --load_best_model_at_end \
-    --learning_rate 5e-5 \
-    --num_train_epochs 3.0 \
-    --max_samples 3000 \
-    --val_size 0.1 \
-    --plot_loss \
-    --fp16
diff --git a/examples/extras/galore/adamw_8bit_bf16.sh b/examples/extras/galore/adamw_8bit_bf16.sh
deleted file mode 100644
index ecb4fa96..00000000
--- a/examples/extras/galore/adamw_8bit_bf16.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-
-CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
-    --stage sft \
-    --do_train \
-    --model_name_or_path meta-llama/Llama-2-7b-hf \
-    --dataset alpaca_gpt4_en,glaive_toolcall \
-    --dataset_dir ../../../data \
-    --template default \
-    --finetuning_type full \
-    --optim adamw_8bit \
-    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
-    --overwrite_cache \
-    --overwrite_output_dir \
-    --cutoff_len 1024 \
-    --preprocessing_num_workers 16 \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 1 \
-    --lr_scheduler_type cosine \
-    --logging_steps 10 \
-    --warmup_steps 20 \
-    --save_steps 100 \
-    --eval_steps 100 \
-    --evaluation_strategy steps \
-    --load_best_model_at_end \
-    --learning_rate 5e-5 \
-    --num_train_epochs 3.0 \
-    --max_samples 3000 \
-    --val_size 0.1 \
-    --plot_loss \
-    --pure_bf16
diff --git a/examples/extras/galore/galore_adamw_8bit_bf16.sh b/examples/extras/galore/galore_adamw_8bit_bf16.sh
deleted file mode 100644
index cedc8bee..00000000
--- a/examples/extras/galore/galore_adamw_8bit_bf16.sh
+++ /dev/null
@@ -1,36 +0,0 @@
-#!/bin/bash
-
-CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
-    --stage sft \
-    --do_train \
-    --model_name_or_path meta-llama/Llama-2-7b-hf \
-    --dataset alpaca_gpt4_en,glaive_toolcall \
-    --dataset_dir ../../../data \
-    --template default \
-    --finetuning_type full \
-    --optim adamw_8bit \
-    --use_galore \
-    --galore_layerwise \
-    --galore_target mlp,self_attn \
-    --galore_rank 128 \
-    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
-    --overwrite_cache \
-    --overwrite_output_dir \
-    --cutoff_len 1024 \
-    --preprocessing_num_workers 16 \
-    --per_device_train_batch_size 1 \
-    --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 1 \
-    --lr_scheduler_type cosine \
-    --logging_steps 10 \
-    --warmup_steps 20 \
-    --save_steps 100 \
-    --eval_steps 100 \
-    --evaluation_strategy steps \
-    --load_best_model_at_end \
-    --learning_rate 5e-5 \
-    --num_train_epochs 3.0 \
-    --max_samples 3000 \
-    --val_size 0.1 \
-    --plot_loss \
-    --pure_bf16
diff --git a/examples/extras/galore/galore_adamw.sh b/examples/extras/galore/sft.sh
similarity index 98%
rename from examples/extras/galore/galore_adamw.sh
rename to examples/extras/galore/sft.sh
index 063bb6df..1ffeb5ca 100644
--- a/examples/extras/galore/galore_adamw.sh
+++ b/examples/extras/galore/sft.sh
@@ -32,4 +32,4 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --max_samples 3000 \
     --val_size 0.1 \
     --plot_loss \
-    --fp16
+    --pure_bf16
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 5a437491..1b96a9dd 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -47,6 +47,8 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     output_layer_names = ["lm_head"]
     if model.config.model_type == "chatglm":
         output_layer_names.append("output_layer")
+    elif model.config.model_type == "internlm2":
+        output_layer_names.append("output")
 
     module_names = set()
     for name, module in model.named_modules():
diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index ed0fe5f1..39e84679 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -8,7 +8,7 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -63,12 +63,16 @@ class CustomDPOTrainer(DPOTrainer):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+    def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
 
-        self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
 
     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
         r"""
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index dff135d2..658b244d 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -13,7 +13,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from ..utils import create_custom_optimzer, create_custom_scheduler, create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer
 
 
@@ -70,7 +70,8 @@ def run_ppo(
     total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
     num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
 
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args, num_training_steps)
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    create_custom_scheduler(training_args, num_training_steps, optimizer)
     if optimizer is None:
         optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
 
diff --git a/src/llmtuner/train/pt/trainer.py b/src/llmtuner/train/pt/trainer.py
index 16e3f5f0..af2848fb 100644
--- a/src/llmtuner/train/pt/trainer.py
+++ b/src/llmtuner/train/pt/trainer.py
@@ -1,12 +1,14 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 from transformers import Trainer
 
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
+    import torch
+
     from ...hparams import FinetuningArguments
 
 
@@ -22,9 +24,13 @@ class CustomTrainer(Trainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
 
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+    def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
 
-        self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py
index 4f5d7190..8d0f2763 100644
--- a/src/llmtuner/train/rm/trainer.py
+++ b/src/llmtuner/train/rm/trainer.py
@@ -1,12 +1,12 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
 
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -29,12 +29,16 @@ class PairwiseTrainer(Trainer):
         self.finetuning_args = finetuning_args
         self.can_return_loss = True  # override property to return eval_loss
 
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+    def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
 
-        self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
 
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 4a49bb27..8d2f9fa0 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -8,7 +8,7 @@ from transformers import Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer
+from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
@@ -29,12 +29,16 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
 
-    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
+    def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
-            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
 
-        self.create_optimizer()
-        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
 
     def prediction_step(
         self,
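The trainer diffs above (PT, RM, SFT, and the DPO trainer earlier) all apply one pattern: override `create_optimizer` and `create_scheduler` separately instead of `create_optimizer_and_scheduler`, plug in the custom factories, then defer to the parent `Trainer`. A minimal sketch of that pattern follows; it is illustrative only, not code from this patch, and `build_custom_optimizer`, `attach_custom_scheduler`, and `PatternTrainer` are hypothetical stand-ins for `create_custom_optimzer`, `create_custom_scheduler`, and the project's trainer subclasses.

```python
from typing import Optional

import torch
from transformers import Trainer


def build_custom_optimizer(model, args) -> Optional[torch.optim.Optimizer]:
    """Stand-in for create_custom_optimzer: return a custom optimizer, or None."""
    return None  # None means "fall back to the Trainer default"


def attach_custom_scheduler(args, num_training_steps: int, optimizer=None) -> None:
    """Stand-in for create_custom_scheduler: attach per-parameter schedulers if needed."""
    return None


class PatternTrainer(Trainer):
    def create_optimizer(self) -> "torch.optim.Optimizer":
        # Build the custom optimizer only once; if it stays None,
        # super().create_optimizer() creates the regular one.
        if self.optimizer is None:
            self.optimizer = build_custom_optimizer(self.model, self.args)
        return super().create_optimizer()

    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        # Let the custom factory register its own schedulers first,
        # then let Trainer build its standard scheduler as usual.
        attach_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)
```

Splitting the override this way is also what lets the PPO workflow call the two factories directly, since it builds its optimizer outside of `Trainer`.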
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 49c42d4e..73854a5e 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -5,6 +5,7 @@ from transformers import Trainer
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
@@ -28,7 +29,13 @@ logger = get_logger(__name__)
 
 
 class DummyOptimizer(torch.optim.Optimizer):
-    def __init__(self, lr: float = 1e-3, optimizer_dict: Optional[dict] = None, *args, **kwargs) -> None:
+    r"""
+    A dummy optimizer used for the GaLore algorithm.
+    """
+
+    def __init__(
+        self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None
+    ) -> None:
         dummy_tensor = torch.randn(1, 1)
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": lr})
@@ -155,8 +162,9 @@ def _create_galore_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    max_steps: int,
 ) -> "torch.optim.Optimizer":
+    require_version("galore_torch", "To fix: pip install galore_torch")
+
     if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all":
         galore_targets = find_all_linear_modules(model)
     else:
@@ -211,29 +219,19 @@
         for param in decay_params:
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
-        for param in galore_params:
+        for param in galore_params:  # galore params have weight decay
             param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)]
             optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
 
-        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
-        for param in trainable_params:
-            scheduler_dict[param] = get_scheduler(
-                training_args.lr_scheduler_type,
-                optimizer=optimizer_dict[param],
-                num_warmup_steps=training_args.get_warmup_steps(max_steps) * 2,
-                num_training_steps=max_steps * 2,
-            )
-
-        def optimizer_hook(param: "torch.Tensor"):
+        def optimizer_hook(param: "torch.nn.Parameter"):
             if param.grad is not None:
                 optimizer_dict[param].step()
                 optimizer_dict[param].zero_grad()
-                scheduler_dict[param].step()
 
         for param in trainable_params:
             param.register_post_accumulate_grad_hook(optimizer_hook)
 
-        optimizer = DummyOptimizer(lr=training_args.learning_rate)  # display scheduler result
+        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
     else:
         param_groups = [
             dict(params=nodecay_params),
@@ -292,10 +290,34 @@ def create_custom_optimzer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    max_steps: int,
 ) -> Optional["torch.optim.Optimizer"]:
     if finetuning_args.use_galore:
-        return _create_galore_optimizer(model, training_args, finetuning_args, max_steps)
+        return _create_galore_optimizer(model, training_args, finetuning_args)
 
     if finetuning_args.loraplus_lr_ratio is not None:
         return _create_loraplus_optimizer(model, training_args, finetuning_args)
+
+
+def create_custom_scheduler(
+    training_args: "Seq2SeqTrainingArguments",
+    num_training_steps: int,
+    optimizer: Optional["torch.optim.Optimizer"] = None,
+) -> None:
+    if optimizer is not None and isinstance(optimizer, DummyOptimizer):
+        optimizer_dict = optimizer.optimizer_dict
+        scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {}
+
+        for param in optimizer_dict.keys():
+            scheduler_dict[param] = get_scheduler(
+                training_args.lr_scheduler_type,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
+                num_training_steps=num_training_steps * 2,
+            )
+
+        def scheduler_hook(param: "torch.nn.Parameter"):
+            if param.grad is not None:
+                scheduler_dict[param].step()
+
+        for param in optimizer_dict.keys():
+            param.register_post_accumulate_grad_hook(scheduler_hook)
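To make the new split concrete: with `--galore_layerwise`, every trainable parameter gets its own single-parameter optimizer stepped from a post-accumulate-grad hook, and `create_custom_scheduler` later walks `DummyOptimizer.optimizer_dict` to attach a per-parameter LR scheduler the same way. Below is a minimal, self-contained sketch of that hook-driven pattern, assuming PyTorch >= 2.1 for `register_post_accumulate_grad_hook`; the toy `nn.Linear` model, plain `AdamW`, and `LambdaLR` are stand-ins for the real model, the `galore_torch` optimizers, and `get_scheduler`, and the optimizer and scheduler steps are folded into one hook here for brevity (the patch registers them as two separate hooks).

```python
import torch

# Toy stand-ins: a tiny model, AdamW instead of galore_torch, LambdaLR instead of get_scheduler().
model = torch.nn.Linear(16, 4)
learning_rate, total_steps = 1e-3, 100

optimizer_dict = {}  # one single-parameter optimizer per trainable parameter
scheduler_dict = {}  # and one scheduler driving that optimizer
for param in filter(lambda p: p.requires_grad, model.parameters()):
    optimizer_dict[param] = torch.optim.AdamW([param], lr=learning_rate)
    scheduler_dict[param] = torch.optim.lr_scheduler.LambdaLR(
        optimizer_dict[param], lambda step: max(0.0, 1.0 - step / total_steps)
    )


def optimizer_hook(param: torch.nn.Parameter) -> None:
    # Fires right after this parameter's gradient is accumulated, so the gradient
    # can be consumed and freed during backward instead of being held for a
    # global optimizer step over the whole model.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()  # the patch does this in a separate scheduler_hook


for param in filter(lambda p: p.requires_grad, model.parameters()):
    param.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop only computes the loss and calls backward();
# the hooks perform all optimizer and scheduler steps.
for _ in range(3):
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
```

This is also why `create_custom_optimzer` no longer needs `max_steps`: only the scheduler construction needs the step count, and that now happens in `create_custom_scheduler` once the trainer knows `num_training_steps`.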