Feature BAdam

commit 06c8908d3f (parent cce52351b5)
@@ -0,0 +1,36 @@
+# BAdam layer-wise
+export CUDA_VISIBLE_DEVICES=0
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+python ../../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../../data \
+    --template default \
+    --finetuning_type full \
+    --output_dir ../../../saves/LLaMA2-7B/badam \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --preprocessing_num_workers 32 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 5 \
+    --gradient_accumulation_steps 2 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --warmup_steps 20 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --val_size 0.1 \
+    --plot_loss \
+    --use_badam \
+    --switch_mode descending \
+    --badam_verbose 2 \
+    --switch_block_every 50 \
+    --pure_bf16
@@ -15,3 +15,4 @@ fastapi
sse-starlette
matplotlib
fire
+badam
@@ -163,6 +163,47 @@ class RLHFArguments:
        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
    )
+
+
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments for the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "The mode of the BAdam optimizer: 'layer' for layer-wise, 'ratio' for ratio-wise."},
+    )
+
+    # ======== Arguments for layer-wise update ========
+    start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for block-wise fine-tuning."},
+    )
+    switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch the block being updated. Set to -1 to disable block switching."},
+    )
+    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "The strategy for picking the block to update."},
+    )
+
+    # ======== Arguments for ratio-wise update ========
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of parameters updated by the BAdam optimizer."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={"help": "The mask mode of the BAdam optimizer. `adjacent` makes the trainable parameters adjacent to each other; `scatter` chooses trainable parameters randomly from the weight."},
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={"help": "The verbosity level of the BAdam optimizer. 0: no output, 1: print the block prefix, 2: print trainable parameters."},
+    )
+
+
@dataclass
class GaloreArguments:

@@ -204,7 +245,7 @@ class GaloreArguments:


@dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
    r"""
    Arguments pertaining to which techniques we are going to fine-tuning with.
    """
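Because FinetuningArguments now also inherits from BAdamArgument, the new fields arrive on the same object the rest of the training code already receives. A minimal, self-contained sketch of that pattern with toy dataclasses (not the real classes):

from dataclasses import dataclass, field

@dataclass
class BAdamArgument:
    use_badam: bool = field(default=False)
    badam_mode: str = field(default="layer")
    switch_block_every: int = field(default=50)

@dataclass
class FinetuningArguments(BAdamArgument):
    finetuning_type: str = field(default="lora")

args = FinetuningArguments(finetuning_type="full", use_badam=True)
print(args.use_badam, args.badam_mode, args.switch_block_every)  # True layer 50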
@@ -171,6 +171,12 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if finetuning_args.use_galore and training_args.deepspeed is not None:
        raise ValueError("GaLore is incompatible with DeepSpeed.")

+    if (
+        finetuning_args.use_badam
+        and finetuning_args.badam_mode == "layer"
+        and training_args.parallel_mode.value == "distributed"
+    ):
+        raise ValueError("BAdam with layer-wise mode does not support distributed training for now, use the ratio mode instead.")
+
    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
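The distributed check above reads training_args.parallel_mode, which is transformers' ParallelMode enum; comparing against ParallelMode.DISTRIBUTED is equivalent to the `.value == "distributed"` string check. A small isolated sketch of the same guard (the helper name is illustrative):

from transformers.training_args import ParallelMode

def check_badam(use_badam: bool, badam_mode: str, parallel_mode: ParallelMode) -> None:
    if use_badam and badam_mode == "layer" and parallel_mode == ParallelMode.DISTRIBUTED:
        raise ValueError("BAdam with layer-wise mode does not support distributed training for now, use the ratio mode instead.")

check_badam(True, "layer", ParallelMode.NOT_DISTRIBUTED)  # no error on a single device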
@@ -37,7 +37,7 @@ def init_adapter(

    if finetuning_args.finetuning_type == "full" and is_trainable:
        logger.info("Fine-tuning method: Full")
-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
            model = model.float()

    if finetuning_args.finetuning_type == "freeze" and is_trainable:

@@ -82,7 +82,7 @@ def init_adapter(

        for name, param in model.named_parameters():
            if any(trainable_layer in name for trainable_layer in trainable_layers):
-                if not finetuning_args.pure_bf16:
+                if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
                    param.data = param.data.to(torch.float32)
            else:
                param.requires_grad_(False)

@@ -162,7 +162,7 @@ def init_adapter(
        )
        model = get_peft_model(model, lora_config)

-        if not finetuning_args.pure_bf16:
+        if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
            for param in filter(lambda p: p.requires_grad, model.parameters()):
                param.data = param.data.to(torch.float32)
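The effect of the three changes above is simply that, when BAdam is active, trainable weights keep the dtype they were loaded in instead of being upcast to fp32. A toy illustration with plain PyTorch (no LLaMA weights involved):

import torch

layer = torch.nn.Linear(4, 4).to(torch.bfloat16)
use_badam, pure_bf16 = True, False

if (not pure_bf16) and (not use_badam):
    layer = layer.float()     # previous behavior for full fine-tuning

print(layer.weight.dtype)     # torch.bfloat16: the block-wise optimizer sees the original weights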
@@ -17,7 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
-from .utils import QuantizationMethod, add_z3_leaf_module
+from .utils import QuantizationMethod, add_z3_leaf_module, gradient_checkpointing_enable


if TYPE_CHECKING:

@@ -266,8 +266,9 @@ def _prepare_model_for_training(
    else:
        # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
        # According to: https://github.com/huggingface/transformers/issues/28339
+        model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
-        model.enable_input_require_grads()
+        # model.enable_input_require_grads()
        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
        logger.info("Gradient checkpointing enabled.")
@@ -135,3 +135,45 @@ def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tok
    model.__class__.register_for_auto_class()
    if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
        tokenizer.__class__.register_for_auto_class()
+
+
+def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+    """
+    Modification of the original method to enable gradient checkpointing for the block-wise optimizer.
+
+    Activates gradient checkpointing for the current model.
+
+    We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+    the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+    Args:
+        gradient_checkpointing_kwargs (dict, *optional*):
+            Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
+    """
+    from torch.utils.checkpoint import checkpoint
+
+    if not self.supports_gradient_checkpointing:
+        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+    if gradient_checkpointing_kwargs is None:
+        gradient_checkpointing_kwargs = {}
+
+    # gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+    def gradient_checkpointing_func(func, *args, **kwargs):
+        module = func.__self__
+
+        if any(p.requires_grad for p in module.parameters()):
+            for arg in args:
+                if torch.is_tensor(arg) and torch.is_floating_point(arg):
+                    arg.requires_grad_(True)
+
+        return checkpoint(func, *args, **kwargs)
+
+    self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+
+    if getattr(self, "_hf_peft_config_loaded", False):
+        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+        # the gradients to make sure the gradient flows.
+        self.enable_input_require_grads()
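The core of the patched method is the wrapper around torch.utils.checkpoint.checkpoint: when the module being checkpointed has trainable parameters, its floating-point inputs are forced to require grad, because with block-wise training the preceding (frozen) blocks no longer produce grad-requiring activations. A self-contained sketch of just that wrapper, with plain PyTorch and illustrative names:

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(8, 8)  # stands in for the currently trainable block

def checkpoint_with_grad_inputs(func, *args, **kwargs):
    module = func.__self__  # __call__ is a bound method, so __self__ is the module
    if any(p.requires_grad for p in module.parameters()):
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)  # otherwise checkpoint has nothing to backpropagate into
    return checkpoint(func, *args, use_reentrant=True, **kwargs)

x = torch.randn(2, 8)  # activation from a frozen block: requires_grad is False
y = checkpoint_with_grad_inputs(block.__call__, x)
y.sum().backward()
print(block.weight.grad is not None)  # True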
@@ -9,7 +9,8 @@ from transformers import Seq2SeqTrainer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ..utils import create_custom_optimzer, create_custom_scheduler
+from types import MethodType
+from packaging import version

if TYPE_CHECKING:
    from transformers.trainer import PredictionOutput

@@ -28,6 +29,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
+        if version.parse(torch.__version__) >= version.parse("1.13"):
+            from badam import clip_grad_norm_for_sparse_tensor
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)

    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
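The constructor patches the accelerator's gradient-clipping method at the instance level; the pattern is plain MethodType rebinding. A sketch with stand-in classes (clip_grad_norm_for_sparse_tensor itself comes from the badam package, as imported above):

from types import MethodType

class Accelerator:  # stand-in for accelerate.Accelerator
    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        return "dense clipping"

def clip_grad_norm_for_sparse_tensor(self, parameters, max_norm, norm_type=2):
    # stand-in for the badam implementation (name taken from the import above)
    return "sparse-aware clipping"

accelerator = Accelerator()
# rebinding affects only this instance, exactly like the trainer's self.accelerator
accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, accelerator)
print(accelerator.clip_grad_norm_([], 1.0))  # sparse-aware clipping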
@@ -287,12 +287,69 @@ def _create_loraplus_optimizer(
    logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio))
    return optimizer


+def _create_badam_optimizer(
+    model: "PreTrainedModel",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+) -> "torch.optim.Optimizer":
+    from transformers.trainer_pt_utils import get_parameter_names
+
+    decay_parameters = list(filter(lambda n: "bias" not in n, get_parameter_names(model, ALL_LAYERNORM_LAYERS)))
+    # filter out the embedding layers when using the badam ratio mode
+    if finetuning_args.badam_mode == "ratio":
+        decay_parameters = list(filter(lambda n: "embed" not in n, decay_parameters))  # TODO: make it more general
+
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
+            "weight_decay": 0.0,
+        },
+    ]
+
+    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+
+    # create BlockOptimizer
+    if finetuning_args.badam_mode == "layer":
+        from badam import BlockOptimizer
+
+        base_optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+        optimizer = BlockOptimizer(
+            base_optimizer=base_optimizer,
+            named_parameters_list=list(model.named_parameters()),
+            block_prefix_list=None,
+            switch_block_every=finetuning_args.switch_block_every,
+            start_block=finetuning_args.start_block,
+            switch_mode=finetuning_args.switch_mode,
+            verbose=finetuning_args.badam_verbose,
+        )
+
+        logger.info(
+            f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.switch_mode}, "
+            f"switch block every {finetuning_args.switch_block_every} steps, "
+            f"default start block is {finetuning_args.start_block}"
+        )
+
+    elif finetuning_args.badam_mode == "ratio":
+        assert finetuning_args.badam_update_ratio > 0.0
+        from badam import BlockOptimizerRatio
+
+        optimizer = BlockOptimizerRatio(
+            param_groups=optimizer_grouped_parameters,
+            named_parameters_list=list(model.named_parameters()),
+            update_ratio=finetuning_args.badam_update_ratio,
+            mask_mode=finetuning_args.badam_mask_mode,
+            verbose=finetuning_args.badam_verbose,
+            **optimizer_kwargs,
+        )
+
+        logger.info(
+            f"Using BAdam optimizer with ratio update, update ratio is {finetuning_args.badam_update_ratio}, "
+            f"mask mode is {finetuning_args.badam_mask_mode}"
+        )
+
+    return optimizer
+
+
def create_custom_optimzer(
    model: "PreTrainedModel",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> Optional["torch.optim.Optimizer"]:
+    if finetuning_args.use_badam:
+        return _create_badam_optimizer(model, training_args, finetuning_args)
+
    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)
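For readers unfamiliar with BAdam, the layer-wise mode boils down to training one block of parameters at a time and rotating the active block every switch_block_every steps, so optimizer state is only needed for the active block. A plain-PyTorch sketch of that idea (an illustration, not the badam package's implementation):

import torch

blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
switch_block_every, active = 2, 0  # mirrors --switch_block_every

def set_active_block(idx: int) -> None:
    for i, block in enumerate(blocks):
        block.requires_grad_(i == idx)  # only one block is trainable at a time

optimizer = torch.optim.AdamW(blocks.parameters(), lr=5e-5)
set_active_block(active)

for step in range(8):
    x = torch.randn(2, 8)
    for block in blocks:
        x = block(x)
    x.sum().backward()
    optimizer.step()       # AdamW skips parameters whose .grad is None
    optimizer.zero_grad()
    if (step + 1) % switch_block_every == 0:
        active = (active + 1) % len(blocks)  # "ascending" switch_mode
        set_active_block(active)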