support galore
This commit is contained in:
parent
725f7cd70f
commit
28f7862188
|
@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||
|
||||
## Changelog
|
||||
|
||||
[24/03/07] We supported [GaLore](https://arxiv.org/abs/2403.03507) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
|
||||
|
||||
[24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
|
||||
|
||||
[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
|
||||
|
|
|
@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
|
|||
|
||||
## 更新日志
|
||||
|
||||
[24/03/07] 我们支持了 [GaLore](https://arxiv.org/abs/2403.03507) 算法。请使用 `--use_galore` 参数切换显存高效的优化器。
|
||||
|
||||
[24/03/07] 我们集成了 [vLLM](https://github.com/vllm-project/vllm) 以实现极速并发推理。请使用 `--infer_backend vllm` 来获得 **270%** 的推理速度。(尚不支持 LoRA,请先合并权重。)
|
||||
|
||||
[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
|
||||
|
|
|
@ -9,9 +9,6 @@ scipy
|
|||
einops
|
||||
sentencepiece
|
||||
protobuf
|
||||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
uvicorn
|
||||
pydantic
|
||||
fastapi
|
||||
|
|
|
@ -21,6 +21,10 @@ def is_flash_attn2_available():
|
|||
return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
|
||||
|
||||
|
||||
def is_galore_available():
|
||||
return _is_package_available("galore_torch")
|
||||
|
||||
|
||||
def is_jieba_available():
|
||||
return _is_package_available("jieba")
|
||||
|
||||
|
|
|
@ -157,7 +157,39 @@ class RLHFArguments:
|
|||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
||||
class GaloreArguments:
|
||||
r"""
|
||||
Arguments pertaining to the GaLore optimization.
|
||||
"""
|
||||
|
||||
use_galore: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use galore optimizer."},
|
||||
)
|
||||
galore_target: str = field(
|
||||
default="mlp,attn",
|
||||
metadata={"help": "Name(s) of modules to apply GaLore."},
|
||||
)
|
||||
galore_rank: int = field(
|
||||
default=16,
|
||||
metadata={"help": "GaLore rank."},
|
||||
)
|
||||
galore_update_interval: int = field(
|
||||
default=200,
|
||||
metadata={"help": "Number of steps to update the GaLore projection."},
|
||||
)
|
||||
galore_scale: float = field(
|
||||
default=0.25,
|
||||
metadata={"help": "GaLore scale."},
|
||||
)
|
||||
galore_proj_type: Literal["std"] = field(
|
||||
default="std",
|
||||
metadata={"help": "Type of GaLore projection."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
|
@ -203,6 +235,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
|
|||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
|
||||
|
||||
if self.use_galore and self.finetuning_type == "lora":
|
||||
raise ValueError("Cannot use LoRA with GaLore together.")
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
|
|
|
@ -180,7 +180,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
|||
|
||||
# Post-process training arguments
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
training_args.parallel_mode.value == "distributed"
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
|
|
|
@ -7,9 +7,9 @@ from ...extras.constants import IGNORE_INDEX
|
|||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ...train.dpo.collator import DPODataCollatorWithPadding
|
||||
from ...train.dpo.trainer import CustomDPOTrainer
|
||||
from ...train.utils import create_modelcard_and_push, create_ref_model
|
||||
from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
|
||||
from .collator import DPODataCollatorWithPadding
|
||||
from .trainer import CustomDPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -44,6 +44,7 @@ def run_dpo(
|
|||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Initialize our Trainer
|
||||
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
loss_type=finetuning_args.dpo_loss,
|
||||
|
@ -54,6 +55,7 @@ def run_dpo(
|
|||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
optimizers=(optimizer, None),
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
|
|
|
@ -13,8 +13,8 @@ from ...extras.callbacks import FixValueHeadModelCallback
|
|||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ...train.ppo.trainer import CustomPPOTrainer
|
||||
from ...train.utils import create_ref_model, create_reward_model
|
||||
from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
|
||||
from .trainer import CustomPPOTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -64,7 +64,10 @@ def run_ppo(
|
|||
)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
|
||||
if optimizer is None:
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
else:
|
||||
|
|
|
@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
|
|||
from ...data import get_dataset, split_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
from ..utils import create_custom_optimzer, create_modelcard_and_push
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -30,12 +30,14 @@ def run_pt(
|
|||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
optimizers=(optimizer, None),
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
|
|
|
@ -7,10 +7,10 @@ from ...extras.callbacks import FixValueHeadModelCallback
|
|||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ...train.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from ...train.rm.metric import compute_accuracy
|
||||
from ...train.rm.trainer import PairwiseTrainer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
from ..utils import create_custom_optimzer, create_modelcard_and_push
|
||||
from .collator import PairwiseDataCollatorWithPadding
|
||||
from .metric import compute_accuracy
|
||||
from .trainer import PairwiseTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -35,12 +35,14 @@ def run_rm(
|
|||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Initialize our Trainer
|
||||
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
|
||||
trainer = PairwiseTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
optimizers=(optimizer, None),
|
||||
compute_metrics=compute_accuracy,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ from ...model import load_model, load_tokenizer
|
|||
from ...train.sft.metric import ComputeMetrics
|
||||
from ...train.sft.trainer import CustomSeq2SeqTrainer
|
||||
from ...train.utils import create_modelcard_and_push
|
||||
from ..utils import create_custom_optimzer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -49,12 +50,14 @@ def run_sft(
|
|||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||
|
||||
# Initialize our Trainer
|
||||
optimizer = create_custom_optimzer(model, training_args, finetuning_args)
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
optimizers=(optimizer, None),
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
|
|
@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
import torch
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_galore_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import load_model_and_tokenizer, load_valuehead_params
|
||||
|
||||
|
||||
if is_galore_available():
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, Trainer
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
@ -118,3 +123,45 @@ def create_reward_model(
|
|||
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
|
||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||
return reward_model
|
||||
|
||||
|
||||
def create_custom_optimzer(
|
||||
model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
|
||||
) -> Optional["torch.optim.Optimizer"]:
|
||||
if not finetuning_args.use_galore:
|
||||
return None
|
||||
|
||||
galore_params = []
|
||||
galore_targets = finetuning_args.galore_target.split(",")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
|
||||
galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
|
||||
|
||||
id_galore_params = [id(p) for p in galore_params]
|
||||
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
|
||||
non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
|
||||
|
||||
# define param groups as galore_params and non_galore_params
|
||||
param_groups = [
|
||||
{"params": non_galore_params},
|
||||
{
|
||||
"params": galore_params,
|
||||
"rank": finetuning_args.galore_rank,
|
||||
"update_proj_gap": finetuning_args.galore_update_interval,
|
||||
"scale": finetuning_args.galore_scale,
|
||||
"proj_type": finetuning_args.galore_proj_type,
|
||||
},
|
||||
]
|
||||
if training_args.optim == "adamw_torch":
|
||||
optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
|
||||
elif training_args.optim == "adamw_8bit":
|
||||
optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
|
||||
elif training_args.optim == "adafactor":
|
||||
optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
|
||||
else:
|
||||
raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
|
||||
|
||||
logger.info("Used the GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||
|
||||
return optimizer
|
||||
|
|
Loading…
Reference in New Issue