support galore

This commit is contained in:
hiyouga 2024-03-07 22:41:36 +08:00
parent 725f7cd70f
commit 28f7862188
12 changed files with 115 additions and 16 deletions

View File: README.md

@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[24/03/07] We supported the [GaLore](https://arxiv.org/abs/2403.03507) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
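For readers who prefer launching runs from Python rather than the CLI, the flags above map one-to-one onto keys of an argument dict. A minimal sketch, assuming the `run_exp` entry point; the model, dataset, and output paths below are placeholders, not part of this commit:

```python
from llmtuner.train.tuner import run_exp  # assumed entry point; not part of this commit

run_exp(dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder model
    dataset="alpaca_gpt4_en",                       # placeholder dataset
    template="default",
    finetuning_type="full",  # GaLore cannot be combined with LoRA (see the check added later in this commit)
    use_galore=True,         # same effect as the --use_galore CLI flag
    output_dir="outputs/galore-sft",                # placeholder path
))
```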

View File: README_zh.md

@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
## Changelog
[24/03/07] We supported the [GaLore](https://arxiv.org/abs/2403.03507) algorithm. Use the `--use_galore` argument to switch to this memory-efficient optimizer.
[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for fast and concurrent inference. Use `--infer_backend vllm` to obtain **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
[24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Use the `--use_dora` argument to perform DoRA fine-tuning.

View File: requirements.txt

@ -9,9 +9,6 @@ scipy
einops
sentencepiece
protobuf
jieba
rouge-chinese
nltk
uvicorn
pydantic
fastapi

View File: src/llmtuner/extras/packages.py

@ -21,6 +21,10 @@ def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2") return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
def is_galore_available():
return _is_package_available("galore_torch")
def is_jieba_available():
return _is_package_available("jieba")

View File: src/llmtuner/hparams/finetuning_args.py

@ -157,7 +157,39 @@ class RLHFArguments:
@dataclass
class GaloreArguments:
r"""
Arguments pertaining to the GaLore optimization.
"""
use_galore: bool = field(
default=False,
metadata={"help": "Whether or not to use galore optimizer."},
)
galore_target: str = field(
default="mlp,attn",
metadata={"help": "Name(s) of modules to apply GaLore."},
)
galore_rank: int = field(
default=16,
metadata={"help": "GaLore rank."},
)
galore_update_interval: int = field(
default=200,
metadata={"help": "Number of steps to update the GaLore projection."},
)
galore_scale: float = field(
default=0.25,
metadata={"help": "GaLore scale."},
)
galore_proj_type: Literal["std"] = field(
default="std",
metadata={"help": "Type of GaLore projection."},
)
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
r""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
@ -203,6 +235,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.use_llama_pro and self.finetuning_type == "full":
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
if self.use_galore and self.finetuning_type == "lora":
raise ValueError("Cannot use LoRA with GaLore together.")
def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
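Because `GaloreArguments` is a plain dataclass that `FinetuningArguments` now mixes in, the new `--galore_*` flags are picked up by the existing dataclass-driven CLI parser with no extra wiring. A short sketch using transformers' `HfArgumentParser` and a trimmed copy of the dataclass above (field metadata omitted for brevity):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class GaloreArguments:  # trimmed copy of the dataclass above
    use_galore: bool = field(default=False)
    galore_target: str = field(default="mlp,attn")
    galore_rank: int = field(default=16)
    galore_update_interval: int = field(default=200)
    galore_scale: float = field(default=0.25)


parser = HfArgumentParser(GaloreArguments)
(galore_args,) = parser.parse_args_into_dataclasses(
    ["--use_galore", "--galore_rank", "8", "--galore_target", "mlp,attn"]
)
print(galore_args.use_galore, galore_args.galore_rank)  # True 8
```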

View File: src/llmtuner/hparams/parser.py

@ -180,7 +180,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
# Post-process training arguments
if (
training_args.parallel_mode.value == "distributed"
and training_args.ddp_find_unused_parameters is None
and finetuning_args.finetuning_type == "lora"
):
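The check above now uses `parallel_mode` instead of inspecting `local_rank` directly, which expresses the distributed-training condition more robustly. A hedged sketch of the full post-processing step; the assignment inside the body is elided from this hunk and the helper name is hypothetical:

```python
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import ParallelMode


def maybe_disable_unused_param_search(training_args: Seq2SeqTrainingArguments, finetuning_type: str) -> None:
    # Under DDP with LoRA, every trainable (adapter) parameter receives a gradient,
    # so DDP's unused-parameter search is pure overhead and can be switched off
    # when the user has not set it explicitly.
    if (
        training_args.parallel_mode == ParallelMode.DISTRIBUTED  # same test as .value == "distributed"
        and training_args.ddp_find_unused_parameters is None
        and finetuning_type == "lora"
    ):
        training_args.ddp_find_unused_parameters = False  # assumed from context; not shown in this hunk
```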

View File: src/llmtuner/train/dpo/workflow.py

@ -7,9 +7,9 @@ from ...extras.constants import IGNORE_INDEX
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
from .collator import DPODataCollatorWithPadding
from .trainer import CustomDPOTrainer
if TYPE_CHECKING:
@ -44,6 +44,7 @@ def run_dpo(
training_args.remove_unused_columns = False # important for pairwise dataset
# Initialize our Trainer
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
trainer = CustomDPOTrainer(
beta=finetuning_args.dpo_beta,
loss_type=finetuning_args.dpo_loss,
@ -54,6 +55,7 @@ def run_dpo(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
optimizers=(optimizer, None),
**split_dataset(dataset, data_args, training_args),
)
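The `optimizers=(optimizer, None)` argument follows the Hugging Face `Trainer` convention that a `None` entry means "create the default": the scheduler is still built from `--lr_scheduler_type`, and when `create_custom_optimzer` returns `None` (GaLore disabled) the trainer falls back to its own default optimizer. The same pattern is applied to the pt, rm, and sft workflows later in this commit. A tiny generic sketch with a hypothetical helper name:

```python
from transformers import Trainer, TrainingArguments


def build_trainer(model, args: TrainingArguments, custom_optimizer=None) -> Trainer:
    # (optimizer, lr_scheduler): leaving a slot as None lets the Trainer create its
    # default for that slot, so the call works both with and without --use_galore.
    return Trainer(model=model, args=args, optimizers=(custom_optimizer, None))
```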

View File: src/llmtuner/train/ppo/workflow.py

@ -13,8 +13,8 @@ from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
from .trainer import CustomPPOTrainer
if TYPE_CHECKING:
@ -64,7 +64,10 @@ def run_ppo(
)
# Create optimizer and scheduler
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
if optimizer is None:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
else:
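PPO is the one workflow that builds its optimizer and scheduler by hand rather than handing them to a `Trainer`, hence the explicit `AdamW` fallback when `create_custom_optimzer` returns `None`. A hedged sketch of that fallback together with the scheduler construction that follows it in the full file; the helper name is hypothetical and the scheduler call is an assumption based on the standard `transformers.get_scheduler` API:

```python
from torch.optim import AdamW
from transformers import Seq2SeqTrainingArguments, get_scheduler


def build_ppo_optimizer(model, training_args: Seq2SeqTrainingArguments, num_training_steps: int, custom_optimizer=None):
    optimizer = custom_optimizer
    if optimizer is None:
        # Default path when --use_galore is off: plain AdamW over trainable parameters only.
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
        num_training_steps=num_training_steps,
    )
    return optimizer, lr_scheduler
```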

View File: src/llmtuner/train/pt/workflow.py

@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
from ...data import get_dataset, split_dataset
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_modelcard_and_push
if TYPE_CHECKING:
@ -30,12 +30,14 @@ def run_pt(
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Initialize our Trainer
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
trainer = Trainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
optimizers=(optimizer, None),
**split_dataset(dataset, data_args, training_args),
)

View File: src/llmtuner/train/rm/workflow.py

@ -7,10 +7,10 @@ from ...extras.callbacks import FixValueHeadModelCallback
from ...extras.misc import fix_valuehead_checkpoint
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..utils import create_custom_optimzer, create_modelcard_and_push
from .collator import PairwiseDataCollatorWithPadding
from .metric import compute_accuracy
from .trainer import PairwiseTrainer
if TYPE_CHECKING:
@ -35,12 +35,14 @@ def run_rm(
training_args.remove_unused_columns = False # important for pairwise dataset
# Initialize our Trainer
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
trainer = PairwiseTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks + [FixValueHeadModelCallback()],
optimizers=(optimizer, None),
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args, training_args),
)

View File: src/llmtuner/train/sft/workflow.py

@ -12,6 +12,7 @@ from ...model import load_model, load_tokenizer
from ...train.sft.metric import ComputeMetrics
from ...train.sft.trainer import CustomSeq2SeqTrainer
from ...train.utils import create_modelcard_and_push
from ..utils import create_custom_optimzer
if TYPE_CHECKING:
@ -49,12 +50,14 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
# Initialize our Trainer
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
optimizers=(optimizer, None),
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args),
)

View File: src/llmtuner/train/utils.py

@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import load_model_and_tokenizer, load_valuehead_params
if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, Trainer
from transformers.modeling_utils import PreTrainedModel
@ -118,3 +123,45 @@ def create_reward_model(
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model)) logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
return reward_model return reward_model
def create_custom_optimzer(
model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
) -> Optional["torch.optim.Optimizer"]:
if not finetuning_args.use_galore:
return None
galore_params = []
galore_targets = finetuning_args.galore_target.split(",")
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
id_galore_params = [id(p) for p in galore_params]
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
# define param groups as galore_params and non_galore_params
param_groups = [
{"params": non_galore_params},
{
"params": galore_params,
"rank": finetuning_args.galore_rank,
"update_proj_gap": finetuning_args.galore_update_interval,
"scale": finetuning_args.galore_scale,
"proj_type": finetuning_args.galore_proj_type,
},
]
if training_args.optim == "adamw_torch":
optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
elif training_args.optim == "adamw_8bit":
optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
elif training_args.optim == "adafactor":
optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
else:
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
return optimizer
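A usage sketch of `create_custom_optimzer`, assuming the module lives at `src/llmtuner/train/utils.py` and that `galore_torch` is installed. The tiny model and the `SimpleNamespace` stand-ins (which carry only the fields the function reads) are illustrative, not this repo's real argument objects:

```python
from types import SimpleNamespace

import torch

from llmtuner.train.utils import create_custom_optimzer  # assumed import path


class TinyBlock(torch.nn.Module):
    # Module names deliberately contain "attn"/"mlp" so they match the default galore_target.
    def __init__(self) -> None:
        super().__init__()
        self.attn = torch.nn.Linear(64, 64)
        self.mlp = torch.nn.Linear(64, 64)


model = TinyBlock()
training_args = SimpleNamespace(optim="adamw_torch", learning_rate=1e-4)  # stand-in for Seq2SeqTrainingArguments
finetuning_args = SimpleNamespace(  # stand-in for FinetuningArguments
    use_galore=True,
    galore_target="mlp,attn",
    galore_rank=4,
    galore_update_interval=200,
    galore_scale=0.25,
    galore_proj_type="std",
)

optimizer = create_custom_optimzer(model, training_args, finetuning_args)
print(type(optimizer).__name__)  # GaLoreAdamW
```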