resolve gradient checkpointing issue.
parent 06c8908d3f
commit 7ecb61822b
@@ -31,6 +31,5 @@ python ../../../src/train_bash.py \
     --use_badam \
     --switch_mode descending \
     --badam_verbose 2 \
-    --switch_block_every 50 \
-    --pure_bf16 \
+    --switch_block_every 50
 
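Note: the example now ends at --switch_block_every 50 (trailing backslash dropped) and no longer passes --pure_bf16. Below is a hedged sketch of how these BAdam flags could map onto fine-tuning argument fields; the field names simply mirror the flags, and the real dataclass in the repository may differ.

# Hypothetical sketch only: argument fields mirroring the BAdam CLI flags above.
from dataclasses import dataclass, field

@dataclass
class BAdamArguments:
    use_badam: bool = field(default=False)           # --use_badam
    switch_mode: str = field(default="descending")   # --switch_mode: order in which blocks become active
    switch_block_every: int = field(default=50)      # --switch_block_every: optimizer steps per active block
    badam_verbose: int = field(default=0)            # --badam_verbose: verbosity of block-switching logs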
setup.py
@@ -24,6 +24,7 @@ extra_require = {
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "unsloth": ["torch==2.2.0", "unsloth[cu121-ampere-torch220]"],
     "galore": ["galore-torch"],
+    "badam": ["torch>=2.1.0"],
     "vllm": ["vllm>=0.3.3"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
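The new "badam" extra only declares a torch floor (torch>=2.1.0); consuming it via something like pip install -e .[badam] from the repository root is an assumption about intended usage, not stated in the commit. A small illustrative runtime guard that mirrors the same floor (not part of this commit):

# Illustrative guard mirroring the torch>=2.1.0 requirement of the "badam" extra.
import torch
from packaging import version

if version.parse(torch.__version__) < version.parse("2.1.0"):
    raise RuntimeError(f"BAdam integration expects torch>=2.1.0, found {torch.__version__}")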
@@ -150,30 +150,24 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
             Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
     """
     from torch.utils.checkpoint import checkpoint
+    import functools
 
     if not self.supports_gradient_checkpointing:
         raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
 
     if gradient_checkpointing_kwargs is None:
-        gradient_checkpointing_kwargs = {}
+        gradient_checkpointing_kwargs = {"use_reentrant": True}
 
-    # gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+    checkpoint = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
 
     def gradient_checkpointing_func(func, *args, **kwargs):
         module = func.__self__
 
-        if any([p.requires_grad for p in module.parameters()]):
+        if any(p.requires_grad for p in module.parameters()):
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
 
         return checkpoint(func, *args, **kwargs)
 
     self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
-
-    if getattr(self, "_hf_peft_config_loaded", False):
-        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
-        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
-        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
-        # the gradients to make sure the gradient flows.
-        self.enable_input_require_grads()
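The key fix is that the functools.partial call is no longer commented out, so gradient_checkpointing_kwargs (now defaulting to use_reentrant=True) actually reaches torch.utils.checkpoint.checkpoint, and the wrapper forces floating-point inputs of trainable modules to require grad so gradients still flow when most blocks are frozen, as with BAdam. Below is a minimal, self-contained sketch of that pattern using a toy nn.Linear block rather than anything from the repository:

# Minimal sketch of the wrapper pattern in the hunk above; the block and tensor
# names are illustrative, not taken from the repository.
import functools
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Bind the checkpoint kwargs once, as the patched code does.
checkpoint = functools.partial(checkpoint, use_reentrant=True)

def gradient_checkpointing_func(func, *args, **kwargs):
    module = func.__self__  # `func` is a bound method, so this is the owning nn.Module
    # If the module has trainable parameters, make floating-point inputs require
    # grad; the reentrant checkpoint produces no parameter gradients when none
    # of its tensor inputs require grad.
    if any(p.requires_grad for p in module.parameters()):
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
    return checkpoint(func, *args, **kwargs)

# Toy usage: checkpoint one block's forward pass and backprop through it.
block = nn.Linear(8, 8)
x = torch.randn(2, 8)
y = gradient_checkpointing_func(block.forward, x)
y.sum().backward()
assert block.weight.grad is not None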
@@ -29,7 +29,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None:
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
-        if version.parse(torch.__version__) >= version.parse("1.13"):
+        if finetuning_args.use_badam:
             from badam import clip_grad_norm_for_sparse_tensor
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
 
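The torch>=1.13 version gate is replaced by a gate on finetuning_args.use_badam, so the Accelerator's clip_grad_norm_ is only overridden when BAdam is actually enabled. The override itself is plain instance-level method binding; here is a tiny sketch with a stand-in class and function instead of accelerate's Accelerator and badam.clip_grad_norm_for_sparse_tensor:

# Stand-in sketch of the MethodType override above; DummyAccelerator and the
# replacement function are illustrative, not accelerate's or BAdam's real code.
from types import MethodType

class DummyAccelerator:
    def clip_grad_norm_(self, parameters, max_norm, norm_type=2.0):
        return "dense clipping"

def clip_grad_norm_sparse_aware(self, parameters, max_norm, norm_type=2.0):
    # `self` is the accelerator instance because MethodType binds it explicitly.
    return "sparse-aware clipping"

acc = DummyAccelerator()
acc.clip_grad_norm_ = MethodType(clip_grad_norm_sparse_aware, acc)  # override on this instance only
print(acc.clip_grad_norm_([], 1.0))  # -> sparse-aware clipping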