support galore
commit 28f7862188 (parent 725f7cd70f)
@@ -70,6 +70,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog
 
+[24/03/07] We supported the [GaLore](https://arxiv.org/abs/2403.03507) algorithm. Try `--use_galore` to use the memory-efficient optimizer.
+
 [24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for faster and concurrent inference. Try `--infer_backend vllm` to enjoy **270%** inference speed. (LoRA is not yet supported, merge it first.)
 
 [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
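
For readers new to the method: GaLore ("Gradient Low-Rank Projection") saves memory by keeping the optimizer state in a low-rank subspace of the *gradients* while still updating all weights, unlike LoRA, which trains low-rank adapters on frozen weights. A schematic of the one-sided update, with notation ours and hedged against the paper:

```latex
% Schematic GaLore step for a weight W_t with gradient G_t (notation ours):
% P_t is a rank-r projection refreshed by SVD every `galore_update_interval`
% steps, r is `galore_rank`, \alpha is `galore_scale`, and \rho_t is the inner
% optimizer's elementwise update (e.g. Adam moments), kept only at rank r.
G_t \in \mathbb{R}^{m \times n}, \qquad P_t \in \mathbb{R}^{m \times r},
\qquad \tilde{G}_t = P_t^{\top} G_t,
\qquad W_{t+1} = W_t - \eta \, \alpha \, P_t \, \rho_t(\tilde{G}_t)
```

The defaults introduced later in this commit (`galore_rank=16`, `galore_update_interval=200`, `galore_scale=0.25`) map directly onto r, the SVD refresh period, and α.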
@@ -70,6 +70,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 ## Changelog
 
+[24/03/07] We supported the [GaLore](https://arxiv.org/abs/2403.03507) algorithm. Use the `--use_galore` argument to switch to the memory-efficient optimizer.
+
 [24/03/07] We integrated [vLLM](https://github.com/vllm-project/vllm) for fast, concurrent inference. Use `--infer_backend vllm` to get **270%** inference speed. (LoRA is not yet supported; merge the weights first.)
 
 [24/02/28] We supported **[DoRA](https://arxiv.org/abs/2402.09353)** fine-tuning. Use the `--use_dora` argument to enable DoRA fine-tuning.
@@ -9,9 +9,6 @@ scipy
 einops
 sentencepiece
 protobuf
-jieba
-rouge-chinese
-nltk
 uvicorn
 pydantic
 fastapi
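
The three packages dropped here back the Chinese SFT metrics (jieba for word segmentation, rouge-chinese and nltk for ROUGE/BLEU), and `extras/packages.py` (next hunk) carries availability probes for them, so they appear to become optional extras. A hedged sketch of guarding on such a probe; the import path and fallback behavior are our assumptions, not code from this commit:

```python
# Hedged sketch: probing for an optional metrics dependency before using it.
from llmtuner.extras.packages import is_jieba_available  # import path assumed

if is_jieba_available():
    import jieba

    tokens = list(jieba.cut("hello world 你好"))  # segment into words
else:
    tokens = []  # degrade gracefully when the optional extra is absent
```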
@@ -21,6 +21,10 @@ def is_flash_attn2_available():
     return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
 
 
+def is_galore_available():
+    return _is_package_available("galore_torch")
+
+
 def is_jieba_available():
     return _is_package_available("jieba")
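
`is_galore_available()` reuses the module's `_is_package_available` helper, whose body is not shown in this diff. A minimal sketch of how such importlib-based probes are commonly implemented; the two underscore-prefixed bodies below are assumptions, only the names come from the context lines:

```python
# Minimal sketch of importlib-based availability probes (bodies assumed).
import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    # A package counts as available iff Python can locate its import spec.
    return importlib.util.find_spec(name) is not None


def _get_package_version(name: str) -> str:
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "0.0.0"  # lets callers do prefix checks like startswith("2")


def is_galore_available() -> bool:
    return _is_package_available("galore_torch")
```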
@@ -157,7 +157,39 @@ class RLHFArguments:
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
+class GaloreArguments:
+    r"""
+    Arguments pertaining to the GaLore optimization.
+    """
+
+    use_galore: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use galore optimizer."},
+    )
+    galore_target: str = field(
+        default="mlp,attn",
+        metadata={"help": "Name(s) of modules to apply GaLore."},
+    )
+    galore_rank: int = field(
+        default=16,
+        metadata={"help": "GaLore rank."},
+    )
+    galore_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the GaLore projection."},
+    )
+    galore_scale: float = field(
+        default=0.25,
+        metadata={"help": "GaLore scale."},
+    )
+    galore_proj_type: Literal["std"] = field(
+        default="std",
+        metadata={"help": "Type of GaLore projection."},
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
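
Because `FinetuningArguments` now mixes in `GaloreArguments`, each new field surfaces as a CLI flag through `HfArgumentParser`, the parsing mechanism the repo's hparams dataclasses already go through. A hedged sketch; the import path is assumed, and note `--finetuning_type full`, since the next hunk makes GaLore plus LoRA an error:

```python
# Hedged sketch: the new GaLore fields parsed as command-line flags.
from transformers import HfArgumentParser

from llmtuner.hparams import FinetuningArguments  # import path assumed

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--finetuning_type", "full",  # GaLore + LoRA would raise a ValueError
        "--use_galore", "true",
        "--galore_rank", "32",
        "--galore_target", "mlp,attn",  # comma-separated module-name filter
    ]
)
assert finetuning_args.use_galore and finetuning_args.galore_rank == 32
```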
@@ -203,6 +235,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")
 
+        if self.use_galore and self.finetuning_type == "lora":
+            raise ValueError("Cannot use LoRA with GaLore together.")
+
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
@@ -180,7 +180,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     # Post-process training arguments
     if (
-        training_args.local_rank != -1
+        training_args.parallel_mode.value == "distributed"
        and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
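
`parallel_mode` is the `ParallelMode` enum property on `TrainingArguments`; its value is `"distributed"` in exactly the DDP case the old `local_rank != -1` test targeted, but the new form states the intent instead of relying on how the launcher sets `local_rank`. A small equivalence sketch; the `TrainingArguments` construction is illustrative:

```python
# Sketch: the string comparison in the diff vs. the enum comparison.
from transformers import TrainingArguments
from transformers.training_args import ParallelMode

training_args = TrainingArguments(output_dir="out")  # illustrative args

is_distributed = training_args.parallel_mode.value == "distributed"
assert is_distributed == (training_args.parallel_mode == ParallelMode.DISTRIBUTED)
```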
@@ -7,9 +7,9 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ...train.dpo.collator import DPODataCollatorWithPadding
-from ...train.dpo.trainer import CustomDPOTrainer
-from ...train.utils import create_modelcard_and_push, create_ref_model
+from ..utils import create_custom_optimzer, create_modelcard_and_push, create_ref_model
+from .collator import DPODataCollatorWithPadding
+from .trainer import CustomDPOTrainer
 
 
 if TYPE_CHECKING:
@@ -44,6 +44,7 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,

@@ -54,6 +55,7 @@ def run_dpo(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
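
This two-line pattern (build a possibly-`None` optimizer, then pass `optimizers=(optimizer, None)`) repeats in the PT, RM, and SFT workflows below. In the HF `Trainer` API, `optimizers` is an `(optimizer, lr_scheduler)` tuple, and a `None` slot means "create the default", so behavior is unchanged when GaLore is off. A minimal sketch; the model and arguments are placeholders:

```python
# Sketch of the shared optimizers=(optimizer, None) pattern (placeholders).
import torch
from transformers import Trainer, TrainingArguments

model = torch.nn.Linear(4, 4)  # stand-in for the loaded LM
args = TrainingArguments(output_dir="out")

optimizer = None  # as when create_custom_optimzer() returns None (no GaLore)
trainer = Trainer(model=model, args=args, optimizers=(optimizer, None))
# Trainer.create_optimizer() builds its default AdamW only if this slot is
# None, and the None scheduler slot always yields the default LR schedule.
```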
@@ -13,8 +13,8 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.ppo.trainer import CustomPPOTrainer
-from ...train.utils import create_ref_model, create_reward_model
+from ..utils import create_custom_optimzer, create_ref_model, create_reward_model
+from .trainer import CustomPPOTrainer
 
 
 if TYPE_CHECKING:
@@ -64,7 +64,10 @@ def run_ppo(
     )
 
     # Create optimizer and scheduler
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    if optimizer is None:
+        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+
     if training_args.max_steps > 0:
         num_training_steps = training_args.max_steps
     else:
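
Unlike the `Trainer`-based workflows, the PPO path hands a concrete optimizer straight to the PPO machinery, which has no "build a default if `None`" step, hence the explicit fallback to AdamW over the trainable parameters. The idiom in isolation, with a placeholder model and learning rate:

```python
# The PPO fallback idiom in isolation (placeholder model and learning rate).
import torch
from torch.optim import AdamW

model = torch.nn.Linear(8, 8)
learning_rate = 1e-5

optimizer = None  # as when use_galore is False
if optimizer is None:
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
```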
@@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer, create_modelcard_and_push
 
 
 if TYPE_CHECKING:
@@ -30,12 +30,14 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         **split_dataset(dataset, data_args, training_args),
     )
@@ -7,10 +7,10 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ...train.rm.collator import PairwiseDataCollatorWithPadding
-from ...train.rm.metric import compute_accuracy
-from ...train.rm.trainer import PairwiseTrainer
-from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer, create_modelcard_and_push
+from .collator import PairwiseDataCollatorWithPadding
+from .metric import compute_accuracy
+from .trainer import PairwiseTrainer
 
 
 if TYPE_CHECKING:
@@ -35,12 +35,14 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks + [FixValueHeadModelCallback()],
+        optimizers=(optimizer, None),
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -12,6 +12,7 @@ from ...model import load_model, load_tokenizer
 from ...train.sft.metric import ComputeMetrics
 from ...train.sft.trainer import CustomSeq2SeqTrainer
 from ...train.utils import create_modelcard_and_push
+from ..utils import create_custom_optimzer
 
 
 if TYPE_CHECKING:
@@ -49,12 +50,14 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 
     # Initialize our Trainer
+    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
+        optimizers=(optimizer, None),
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **split_dataset(dataset, data_args, training_args),
     )
@@ -3,10 +3,15 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 from ..extras.logging import get_logger
+from ..extras.packages import is_galore_available
 from ..hparams import FinetuningArguments, ModelArguments
 from ..model import load_model_and_tokenizer, load_valuehead_params
 
 
+if is_galore_available():
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
@@ -118,3 +123,45 @@ def create_reward_model(
         logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model
+
+
+def create_custom_optimzer(
+    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+) -> Optional["torch.optim.Optimizer"]:
+    if not finetuning_args.use_galore:
+        return None
+
+    galore_params = []
+    galore_targets = finetuning_args.galore_target.split(",")
+
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets):
+            galore_params += list(filter(lambda p: p.requires_grad, module.parameters()))
+
+    id_galore_params = [id(p) for p in galore_params]
+    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
+
+    # define param groups as galore_params and non_galore_params
+    param_groups = [
+        {"params": non_galore_params},
+        {
+            "params": galore_params,
+            "rank": finetuning_args.galore_rank,
+            "update_proj_gap": finetuning_args.galore_update_interval,
+            "scale": finetuning_args.galore_scale,
+            "proj_type": finetuning_args.galore_proj_type,
+        },
+    ]
+    if training_args.optim == "adamw_torch":
+        optimizer = GaLoreAdamW(param_groups, lr=training_args.learning_rate)
+    elif training_args.optim == "adamw_8bit":
+        optimizer = GaLoreAdamW8bit(param_groups, lr=training_args.learning_rate)
+    elif training_args.optim == "adafactor":
+        optimizer = GaLoreAdafactor(param_groups, lr=training_args.learning_rate)
+    else:
+        raise NotImplementedError("Unknown optim: {}".format(training_args.optim))
+
+    logger.info("Using the GaLore optimizer may cause a hang at the start of training; please wait patiently.")
+
+    return optimizer
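
A hedged end-to-end sketch of the new entry point. The toy module names match the default `galore_target="mlp,attn"`, so those `Linear` layers land in the GaLore parameter group while `head` stays in the plain group. Only the class and argument names come from this diff; the scaffolding (imports, toy model) and the assumption that any `nn.Module` works here are ours, and running it needs `galore_torch` installed:

```python
# Hedged usage sketch for create_custom_optimzer (scaffolding is ours).
import torch
from transformers import Seq2SeqTrainingArguments

from llmtuner.hparams import FinetuningArguments  # import paths assumed
from llmtuner.train.utils import create_custom_optimzer


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(16, 16)  # matched by default galore_target
        self.mlp = torch.nn.Linear(16, 16)   # matched as well
        self.head = torch.nn.Linear(16, 2)   # unmatched: stays non-GaLore


model = TinyBlock()
training_args = Seq2SeqTrainingArguments(output_dir="out", optim="adamw_torch")
finetuning_args = FinetuningArguments(finetuning_type="full", use_galore=True)

optimizer = create_custom_optimzer(model, training_args, finetuning_args)
# With use_galore=False this returns None and callers fall back to defaults.
```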