support layerwise galore
parent 18ffce36b5
commit 8664262cde
@@ -276,16 +276,13 @@ huggingface-cli login
 | ------ | ---- | ----- | ----- | ----- | ------ | ------ |
 | Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
 | Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
+| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
 | LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
 | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
 | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
 
-> [!NOTE]
-> We report the GaLore results without per-layer weight updates.
-
 ## Getting Started
 
 ### Data Preparation (optional)
The same change, applied to the Chinese-language documentation (shown here in translation):

@@ -276,16 +276,13 @@ huggingface-cli login
 | ------- | ---- | ----- | ----- | ----- | ------ | ------ |
 | Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
 | Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
+| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
 | LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
 | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
 | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
 
-> [!NOTE]
-> The GaLore results above do not include layer-wise weight updates.
-
 ## Getting Started
 
 ### Data Preparation (optional)
@@ -15,7 +15,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
@@ -16,7 +16,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
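The switch from gradient_accumulation_steps 2 to 1 in these example scripts matches a check enforced later in this commit: with layer-wise GaLore the optimizer step runs inside a per-parameter gradient hook, which fires on every backward pass, so gradients cannot be accumulated across micro-batches before updating. A toy illustration of that behavior, assuming PyTorch 2.1 or newer for register_post_accumulate_grad_hook (illustrative code, not part of this commit):

import torch

w = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.SGD([w], lr=0.1)
num_updates = 0

def hook(param):
    # Runs every time a backward pass finishes accumulating this parameter's grad.
    global num_updates
    opt.step()
    opt.zero_grad()
    num_updates += 1

w.register_post_accumulate_grad_hook(hook)

for _ in range(2):  # two would-be "accumulation" micro-batches
    (w * 2.0).sum().backward()

print(num_updates)  # 2 -- one update per backward pass, not one per accumulated batch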
@@ -9,8 +9,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --template default \
     --finetuning_type full \
     --use_galore \
+    --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_rank 32 \
+    --galore_rank 128 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -18,7 +19,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
@@ -10,8 +10,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --finetuning_type full \
     --optim adamw_8bit \
     --use_galore \
+    --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_rank 16 \
+    --galore_rank 128 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -19,7 +20,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
@@ -29,7 +29,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-):
+) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -182,6 +182,10 @@ class GaloreArguments:
         default="std",
         metadata={"help": "Type of GaLore projection."},
     )
+    galore_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
 
 
 @dataclass
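The new galore_layerwise field becomes a command-line flag through transformers' HfArgumentParser, which is how --galore_layerwise in the example scripts above reaches the training code. A minimal sketch using a simplified stand-in dataclass (not the repository's actual GaloreArguments):

# Simplified stand-in for GaloreArguments; shows how a boolean dataclass field
# with default=False turns into an opt-in CLI flag via HfArgumentParser.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class GaloreArgs:
    galore_rank: int = field(default=16, metadata={"help": "Rank of the GaLore projection."})
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )


parser = HfArgumentParser(GaloreArgs)
(args,) = parser.parse_args_into_dataclasses(["--galore_rank", "128", "--galore_layerwise"])
print(args.galore_rank, args.galore_layerwise)  # 128 True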
@@ -44,7 +44,7 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
@@ -64,7 +64,7 @@ def run_ppo(
     )
 
     # Create optimizer and scheduler
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     if optimizer is None:
         optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
 
@@ -30,7 +30,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = Trainer(
         model=model,
         args=training_args,
@@ -35,7 +35,7 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
@@ -50,7 +50,7 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
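Each training workflow now builds the (possibly layer-wise) optimizer before constructing its trainer, passing the dataset so the total number of training steps can be derived for the per-parameter schedulers. These hunks do not show how the optimizer is handed over; one standard way with a transformers Trainer is the optimizers tuple, sketched below with placeholder objects (an assumption about the wiring, not code from this commit):

# Hedged sketch: a pre-built optimizer can be given to a transformers Trainer through
# the standard `optimizers` tuple (optimizer, lr_scheduler); passing None for the
# scheduler lets the Trainer create its usual one. The model here is a placeholder.
import torch
from transformers import Trainer, TrainingArguments

model = torch.nn.Linear(4, 2)  # stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # e.g. the custom GaLore optimizer

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    optimizers=(optimizer, None),  # Trainer will not create its own optimizer
)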
@@ -1,6 +1,8 @@
-from typing import TYPE_CHECKING, Optional, Union
+import math
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 import torch
+from transformers.optimization import get_scheduler
 from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
@@ -14,6 +16,7 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -24,6 +27,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class DummyOptimizer(torch.optim.Optimizer):
+    def __init__(self, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        super().__init__([dummy_tensor], {"lr": 1e-3})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        pass
+
+
 def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
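DummyOptimizer exists because, in the layer-wise mode introduced by this commit, every parameter is updated by its own optimizer inside a gradient hook; the Trainer still expects an optimizer object, so it receives one whose step() and zero_grad() do nothing. A stripped-down sketch of the same idea (illustrative, not the committed code):

import torch


class NoOpOptimizer(torch.optim.Optimizer):
    """Placeholder optimizer: satisfies the interface, performs no updates."""

    def __init__(self):
        # torch.optim.Optimizer insists on at least one parameter group,
        # so register a throwaway tensor it will never touch.
        super().__init__([torch.randn(1, 1)], {"lr": 1e-3})

    def zero_grad(self, set_to_none: bool = True) -> None:
        pass  # gradients are cleared by the per-parameter hooks instead

    def step(self, closure=None):
        pass  # updates are performed by the per-parameter hooks instead


opt = NoOpOptimizer()
opt.step()       # harmless no-op inside an unmodified training loop
opt.zero_grad()  # harmless no-op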
@@ -127,7 +142,10 @@ def create_reward_model(
 
 
 def create_custom_optimzer(
-    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+    model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
 ) -> Optional["torch.optim.Optimizer"]:
     if not finetuning_args.use_galore:
         return None
@@ -144,40 +162,80 @@ def create_custom_optimzer(
     trainable_params = filter(lambda p: p.requires_grad, model.parameters())
     non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params]
 
-    # define param groups as galore_params and non_galore_params
-    param_groups = [
-        {"params": non_galore_params},
-        {
-            "params": galore_params,
-            "rank": finetuning_args.galore_rank,
-            "update_proj_gap": finetuning_args.galore_update_interval,
-            "scale": finetuning_args.galore_scale,
-            "proj_type": finetuning_args.galore_proj_type,
-        },
-    ]
     if training_args.optim == "adamw_torch":
-        optimizer = GaLoreAdamW(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-        )
+        optim_class = GaLoreAdamW
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+        }
+
     elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]:
-        optimizer = GaLoreAdamW8bit(
-            param_groups,
-            lr=training_args.learning_rate,
-            eps=training_args.adam_epsilon,
-            betas=(training_args.adam_beta1, training_args.adam_beta2),
-            optim_bits=8,
-            is_paged="paged" in training_args.optim,
-        )
+        optim_class = GaLoreAdamW8bit
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+            "eps": training_args.adam_epsilon,
+            "betas": (training_args.adam_beta1, training_args.adam_beta2),
+            "optim_bits": 8,
+            "is_paged": "paged" in training_args.optim,
+        }
+
     elif training_args.optim == "adafactor":
-        optimizer = GaLoreAdafactor(
-            param_groups,
-            lr=training_args.learning_rate,
-        )
+        optim_class = GaLoreAdafactor
+        optim_kwargs = {
+            "lr": training_args.learning_rate,
+        }
+
     else:
         raise NotImplementedError("Unknow optim: {}".format(training_args.optim))
 
+    galore_kwargs = {
+        "rank": finetuning_args.galore_rank,
+        "update_proj_gap": finetuning_args.galore_update_interval,
+        "scale": finetuning_args.galore_scale,
+        "proj_type": finetuning_args.galore_proj_type,
+    }
+
+    if finetuning_args.galore_layerwise:
+        if training_args.gradient_accumulation_steps != 1:
+            raise ValueError("Per-layer GaLore does not support gradient accumulation.")
+
+        if training_args.max_steps > 0:
+            num_training_steps = training_args.max_steps
+        else:
+            total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size
+            num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
+
+        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
+        for param in non_galore_params:
+            param_groups = [dict(params=[param])]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+        for param in galore_params:
+            param_groups = [dict(params=[param], **galore_kwargs)]
+            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
+
+        scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {}
+        for param in non_galore_params + galore_params:
+            scheduler_dict[param] = get_scheduler(
+                training_args.lr_scheduler_type,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2,
+                num_training_steps=num_training_steps * 2,
+            )
+
+        def optimizer_hook(param: "torch.Tensor"):
+            if param.grad is not None:
+                optimizer_dict[param].step()
+                optimizer_dict[param].zero_grad()
+                scheduler_dict[param].step()
+
+        for param in non_galore_params + galore_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer()
+    else:
+        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
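Putting the pieces together, the layer-wise branch replaces one global optimizer with one optimizer and one learning-rate scheduler per parameter, all driven by post-accumulate-grad hooks, while the Trainer receives the DummyOptimizer defined earlier. A self-contained toy version of that pattern, using plain AdamW in place of the GaLore optimizers (assumes PyTorch 2.1+ and transformers; illustrative only, not the committed code):

import torch
from transformers.optimization import get_scheduler

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
num_training_steps = 100

# One optimizer and one scheduler per parameter.
optimizer_dict, scheduler_dict = {}, {}
for param in model.parameters():
    optimizer_dict[param] = torch.optim.AdamW([param], lr=1e-3)
    scheduler_dict[param] = get_scheduler(
        "cosine",
        optimizer=optimizer_dict[param],
        num_warmup_steps=10,
        num_training_steps=num_training_steps,
    )

def optimizer_hook(param):
    # Step this parameter's optimizer and scheduler as soon as its grad is ready.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()
        scheduler_dict[param].step()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

# Training loop: only forward/backward are needed; no global optimizer.step().
for _ in range(3):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()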