fix #3316
This commit is contained in:
parent
6d641af703
commit
c9a477322d
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
from enum import Enum, unique
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
@ -129,7 +130,11 @@ def gradient_checkpointing_enable(
|
|||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||
else:
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||
|
||||
|
||||
def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
|
||||
|
|
Loading…
Reference in New Issue