support layerwise galore

hiyouga 2024-03-10 00:24:11 +08:00
parent 18ffce36b5
commit 8664262cde
14 changed files with 109 additions and 51 deletions

View File

@@ -276,16 +276,13 @@ huggingface-cli login
| ------ | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
> [!NOTE]
> We report the GaLore results without per-layer weight updates.
## Getting Started
### Data Preparation (optional)

View File

@@ -276,16 +276,13 @@ huggingface-cli login
| ------- | ---- | ----- | ----- | ----- | ------ | ------ |
| Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
| Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
| LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
| QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
| QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
> [!NOTE]
> We report the GaLore results without per-layer weight updates.
## Getting Started
### Data Preparation (optional)

View File

@@ -15,7 +15,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \

View File

@@ -16,7 +16,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \

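Both example scripts drop `--gradient_accumulation_steps` from 2 to 1, which matches the constraint introduced later in this commit: the layer-wise GaLore path refuses to run with gradient accumulation. A minimal arithmetic sketch of the resulting effective batch size, assuming a single visible GPU as in the commands above (variable names here are illustrative, not taken from the scripts):

```python
# Illustrative arithmetic only; assumes one visible GPU (CUDA_VISIBLE_DEVICES=0).
per_device_train_batch_size = 1
world_size = 1
gradient_accumulation_steps = 1  # was 2 before this commit

effective_batch_size = per_device_train_batch_size * world_size * gradient_accumulation_steps
print(effective_batch_size)  # 1 -- half the previous effective batch of 2
```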
View File

@@ -9,8 +9,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --template default \
    --finetuning_type full \
    --use_galore \
    --galore_layerwise \
    --galore_target mlp,self_attn \
    --galore_rank 32 \
    --galore_rank 128 \
    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
    --overwrite_cache \
    --overwrite_output_dir \

@@ -18,7 +19,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \

View File

@@ -10,8 +10,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --finetuning_type full \
    --optim adamw_8bit \
    --use_galore \
    --galore_layerwise \
    --galore_target mlp,self_attn \
    --galore_rank 16 \
    --galore_rank 128 \
    --output_dir ../../../saves/LLaMA2-7B/galore/sft \
    --overwrite_cache \
    --overwrite_output_dir \

@@ -19,7 +20,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --warmup_steps 20 \

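These two scripts turn on `--galore_layerwise`, keep `--galore_target mlp,self_attn`, and raise `--galore_rank` to 128. Based on the optimizer code further down in this diff, the non-layer-wise equivalent of these flags is roughly the parameter-group setup sketched below. This is a hedged illustration, not the project's code: it assumes `galore_torch` (and its bitsandbytes dependency) is installed, uses a toy module purely so the name-based filtering has `self_attn` and `mlp` weights to match, and all GaLore values other than the rank are placeholders.

```python
# Hedged sketch of what the GaLore flags map to in the single-optimizer case.
import torch
from galore_torch import GaLoreAdamW8bit

class ToyBlock(torch.nn.Module):
    """Toy stand-in for a transformer block with "self_attn" and "mlp" submodules."""
    def __init__(self) -> None:
        super().__init__()
        self.self_attn = torch.nn.Linear(256, 256)
        self.mlp = torch.nn.Linear(256, 1024)
        self.norm = torch.nn.LayerNorm(256)

model = ToyBlock()

galore_params = [
    p for n, p in model.named_parameters()
    if p.requires_grad and p.ndim == 2 and any(t in n for t in ("mlp", "self_attn"))  # --galore_target
]
galore_ids = {id(p) for p in galore_params}
other_params = [p for p in model.parameters() if p.requires_grad and id(p) not in galore_ids]

param_groups = [
    {"params": other_params},
    # rank mirrors --galore_rank 128; the remaining GaLore values are placeholders.
    {"params": galore_params, "rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"},
]
optimizer = GaLoreAdamW8bit(param_groups, lr=1e-4)  # the --optim adamw_8bit path
```

With `--galore_layerwise`, the commit instead builds one such optimizer per parameter and drives them from gradient hooks, as shown in `create_custom_optimzer` below.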
View File

@@ -29,7 +29,7 @@ def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
):
) -> Union["Dataset", "IterableDataset"]:
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:

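The only change to the data loader is a return annotation: `load_single_dataset` is now typed as returning either a `Dataset` or an `IterableDataset`, because Hugging Face `datasets` produces the latter in streaming mode. A small illustration of the two cases (the public `imdb` dataset is used purely as an example); note that only the map-style `Dataset` supports `len()`, which the new step-count computation further down relies on.

```python
# Illustration of why the return type is a Union; "imdb" is just an example dataset.
from datasets import Dataset, IterableDataset, load_dataset

regular = load_dataset("imdb", split="train")                   # map-style -> Dataset, has len()
streamed = load_dataset("imdb", split="train", streaming=True)  # lazy -> IterableDataset, no len()

assert isinstance(regular, Dataset)
assert isinstance(streamed, IterableDataset)
```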
View File

@@ -182,6 +182,10 @@ class GaloreArguments:
        default="std",
        metadata={"help": "Type of GaLore projection."},
    )
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )


@dataclass

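The new `galore_layerwise` option is an ordinary dataclass field, so it surfaces on the command line just like the existing GaLore flags. Below is a minimal sketch of how such a field is parsed, assuming `transformers.HfArgumentParser` (which LLaMA-Factory's argument parsing builds on); the toy dataclass only mirrors the new field and is not the project's full `GaloreArguments`.

```python
# Minimal sketch, not the project's parser: a boolean dataclass field becomes a
# --galore_layerwise switch via HfArgumentParser.
from dataclasses import dataclass, field
from transformers import HfArgumentParser


@dataclass
class MiniGaloreArguments:
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )


parser = HfArgumentParser(MiniGaloreArguments)
(args,) = parser.parse_args_into_dataclasses(["--galore_layerwise"])
print(args.galore_layerwise)  # True
```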
View File

@@ -44,7 +44,7 @@ def run_dpo(
    training_args.remove_unused_columns = False  # important for pairwise dataset

    # Initialize our Trainer
    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    trainer = CustomDPOTrainer(
        beta=finetuning_args.dpo_beta,
        loss_type=finetuning_args.dpo_loss,

View File

@@ -64,7 +64,7 @@ def run_ppo(
    )

    # Create optimizer and scheduler
    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    if optimizer is None:
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)

View File

@@ -30,7 +30,7 @@ def run_pt(
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Initialize our Trainer
    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    trainer = Trainer(
        model=model,
        args=training_args,

View File

@@ -35,7 +35,7 @@ def run_rm(
    training_args.remove_unused_columns = False  # important for pairwise dataset

    # Initialize our Trainer
    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    trainer = PairwiseTrainer(
        model=model,
        args=training_args,

View File

@@ -50,7 +50,7 @@ def run_sft(
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams

    # Initialize our Trainer
    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,

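All five training workflows (pt, sft, rm, dpo, ppo) now hand the dataset to `create_custom_optimzer`. The layer-wise branch needs it to derive the total number of optimizer steps for the per-parameter schedulers when `max_steps` is not set. A worked example with assumed numbers (50,000 samples, batch size 1 per device, 8 devices, 3 epochs), mirroring the computation in the utility code below:

```python
# Worked example with assumed numbers; mirrors the num_training_steps computation
# in the layer-wise GaLore branch.
import math

len_dataset = 50_000                # assumed dataset size
per_device_train_batch_size = 1
world_size = 8                      # assumed number of devices
num_train_epochs = 3

total_train_batch_size = per_device_train_batch_size * world_size          # 8
num_training_steps = num_train_epochs * math.ceil(len_dataset / total_train_batch_size)
print(num_training_steps)  # 3 * 6250 = 18750; each scheduler below is created with 2x this value
```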
View File

@@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Optional, Union
import math
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
import torch
from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
@@ -14,6 +16,7 @@ if is_galore_available():

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments, Trainer
    from transformers.modeling_utils import PreTrainedModel
    from trl import AutoModelForCausalLMWithValueHead
@@ -24,6 +27,18 @@ if TYPE_CHECKING:

logger = get_logger(__name__)


class DummyOptimizer(torch.optim.Optimizer):
    def __init__(self, *args, **kwargs):
        dummy_tensor = torch.randn(1, 1)
        super().__init__([dummy_tensor], {"lr": 1e-3})

    def zero_grad(self, set_to_none: bool = True) -> None:
        pass

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        pass


def create_modelcard_and_push(
    trainer: "Trainer",
    model_args: "ModelArguments",

@@ -127,7 +142,10 @@ def create_reward_model(

def create_custom_optimzer(
    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
    model: "PreTrainedModel",
    dataset: Union["Dataset", "IterableDataset"],
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
    if not finetuning_args.use_galore:
        return None
@@ -144,40 +162,80 @@ def create_custom_optimzer
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]

    # define param groups as galore_params and non_galore_params
    param_groups = [
        {"params": non_galore_params},
        {
            "params": galore_params,
            "rank": finetuning_args.galore_rank,
            "update_proj_gap": finetuning_args.galore_update_interval,
            "scale": finetuning_args.galore_scale,
            "proj_type": finetuning_args.galore_proj_type,
        },
    ]

    if training_args.optim == "adamw_torch":
        optimizer = GaLoreAdamW(
            param_groups,
            lr=training_args.learning_rate,
            eps=training_args.adam_epsilon,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
        )
        optim_class = GaLoreAdamW
        optim_kwargs = {
            "lr": training_args.learning_rate,
            "eps": training_args.adam_epsilon,
            "betas": (training_args.adam_beta1, training_args.adam_beta2),
        }
    elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
        optimizer = GaLoreAdamW8bit(
            param_groups,
            lr=training_args.learning_rate,
            eps=training_args.adam_epsilon,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            optim_bits=8,
            is_paged="paged" in training_args.optim,
        )
        optim_class = GaLoreAdamW8bit
        optim_kwargs = {
            "lr": training_args.learning_rate,
            "eps": training_args.adam_epsilon,
            "betas": (training_args.adam_beta1, training_args.adam_beta2),
            "optim_bits": 8,
            "is_paged": "paged" in training_args.optim,
        }
    elif training_args.optim == "adafactor":
        optimizer = GaLoreAdafactor(
            param_groups,
            lr=training_args.learning_rate,
        )
        optim_class = GaLoreAdafactor
        optim_kwargs = {
            "lr": training_args.learning_rate,
        }
    else:
        raise NotImplementedError("Unknown optim: {}".format(training_args.optim))
    galore_kwargs = {
        "rank": finetuning_args.galore_rank,
        "update_proj_gap": finetuning_args.galore_update_interval,
        "scale": finetuning_args.galore_scale,
        "proj_type": finetuning_args.galore_proj_type,
    }

    if finetuning_args.galore_layerwise:
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer GaLore does not support gradient accumulation.")

        if training_args.max_steps > 0:
            num_training_steps = training_args.max_steps
        else:
            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)

        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
        for param in non_galore_params:
            param_groups = [dict(params=[param])]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in galore_params:
            param_groups = [dict(params=[param], **galore_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
        for param in non_galore_params + galore_params:
            scheduler_dict[param] = get_scheduler(
                training_args.lr_scheduler_type,
                optimizer=optimizer_dict[param],
                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
                num_training_steps=num_training_steps * 2,
            )

        def optimizer_hook(param: "torch.Tensor"):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()
                scheduler_dict[param].step()

        for param in non_galore_params + galore_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer()
    else:
        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
        optimizer = optim_class(param_groups, **optim_kwargs)

    logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
    return optimizer
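The core of the layer-wise mode is the hook-driven update: every trainable parameter gets its own optimizer and scheduler, the actual `step()` calls happen inside `register_post_accumulate_grad_hook` callbacks during `backward()`, and the Trainer receives a `DummyOptimizer` whose `step()` and `zero_grad()` deliberately do nothing. The self-contained sketch below demonstrates the same pattern with plain `torch.optim.AdamW` standing in for the GaLore optimizers; it illustrates the mechanism, not the project's code, and requires PyTorch 2.1+ for the hook.

```python
# Self-contained sketch of the per-parameter ("layer-wise") update pattern, using
# plain AdamW instead of the GaLore optimizers. Requires PyTorch >= 2.1 for
# Tensor.register_post_accumulate_grad_hook.
import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

# One optimizer per parameter, so each gradient can be consumed and freed right
# after it is accumulated -- this is where the memory saving comes from.
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters() if p.requires_grad}


def optimizer_hook(param: torch.Tensor) -> None:
    # Fires once the gradient of `param` has been fully accumulated in backward().
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()


for param in optimizer_dict:
    param.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop (or a Trainer holding a DummyOptimizer) only calls backward();
# every parameter has already been updated and its gradient cleared by the hooks.
inputs, targets = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = torch.nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
```

This also shows why the layer-wise path rejects gradient accumulation: each hook steps its optimizer on every backward pass, which is incompatible with accumulating gradients across micro-batches, hence the `ValueError` above and the `gradient_accumulation_steps` changes in the example scripts.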