# OpenDeltaMirror/opendelta/basemodel.py
from collections import OrderedDict
import os

from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.inspect import inspect_module_statistics
from opendelta.utils.model_md5 import gen_model_hash
from opendelta.utils.signature import get_arg_names, signature
from typing import Optional, Union
from opendelta.utils.cuda import get_device
from opendelta.utils.name_based_addressing import *
import torch.nn as nn
import torch
from functools import wraps
# from decorator import decorate
from opendelta.utils.decorate import decorate
from opendelta.utils.structure_mapping import transform
from transformers.file_utils import PushToHubMixin
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from opendelta import SaveLoadMixin
from opendelta import logging
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.cuda import move_dict_to_cuda
import sys

from opendelta.utils.data_parallel import caller_map
from opendelta.utils.backend import BackendMapping

logger = logging.get_logger(__name__)

def is_leaf_module(module):
r"""Whether the module is a leaf module
"""
return len([n for n,_ in module.named_children()]) == 0
def non_module_param(module: nn.Module):
module_names = [n for n, _ in module.named_modules()]
ret = []
for n, p in module.named_parameters():
if not is_child_key(n, module_names):
ret.append((n,p))
return ret
class DeltaBase(nn.Module, SaveLoadMixin):
    r"""This is the base class for all delta models. It provides four simple but effective functionalities
    for building the delta model:

    #. addressing a module inside the backbone model using a minimal description key.
    #. providing an interface for modifying and inserting modules that keeps the docs/IO the same as before
       the modification.
    #. passing a pseudo input through the backbone to determine the inner dimensions of the delta modules.
    #. freezing a part of the model parameters according to keys.

    It also provides a unified interface for model loading and saving.

    Class attributes (overridden by derived classes):

        - delta_type (:obj:`str`): the name of the delta module type, used to create the correct :class:`opendelta.AutoDeltaModel`.
        - config_class (:class:`BaseDeltaConfig`): The corresponding config class.

    Args:
        backbone_model (:obj:`nn.Module`, *required*): The backbone model that the delta modules are built upon. The
            modification to the backbone model is in place.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that are subject to update.

            .. note::
                Leaving this argument as :obj:`None` makes the delta model fall back to its default setting, which adds
                the delta modules to the positions experimented with in the paper. In this setting, the common structure
                mapping is loaded to address the corresponding modules.

        exclude_modules (:obj:`str`, *optional*, default to :obj:`None`): Modules whose names start with these strings will be excluded from modification.
            Note that currently only plain text (no regular expression) is supported.
        unfrozen_modules (:obj:`str`, *optional*, default to :obj:`None`): The modules that are **not** frozen when freezing the main part of the model.
        registraction_name (:obj:`str`, *optional*, default to ``"deltas"``): The root name of the delta modules when
            attached to the backbone model.
        common_structure (:obj:`bool`, *optional*, default to :obj:`None`): Whether to use the common structure mapping to specify
            modified_modules, i.e., if common_structure=True, we use a common ["attn"] to address the attention module in different models.
            We DO NOT recommend manually setting ``common_structure`` to ``True`` unless you are using delta
            modules across multiple backbones and don't want to modify the code.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`None`): Whether to use interactive modification.
            Setting it to an :obj:`int` specifies the port of the web server.
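
    Example (a minimal usage sketch; it assumes a locally available Hugging Face backbone and that ``LoraModel`` is exported at the package top level):

    .. code-block:: python

        from transformers import AutoModelForSeq2SeqLM
        from opendelta import LoraModel   # assumption: LoraModel is importable from the top-level package

        backbone = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
        delta_model = LoraModel(backbone_model=backbone)   # add LoRA modules at the default positions
        delta_model.freeze_module(exclude=["deltas"])      # freeze everything except the delta parameters
        delta_model.log()                                  # show which parameters remain trainable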
"""
    delta_type = ""
    default_modified_modules = []
    default_exclude_modules = ["lm_head"]
    config_class = BaseDeltaConfig
    default_unfrozen_modules = ["deltas"]
    _need_pseudo_data = True
    _supported_backends = ['hf']

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 common_structure: Optional[bool] = False,
                 backend: Optional[str] = "hf",  # select from ["hf", "bmt"]
                 ):
        nn.Module.__init__(self)
        # register the backbone model after init using the self.__dict__ method to avoid adding backbone_model
        # to the modules of the delta model.
        self.__dict__["backbone_model"] = backbone_model
        if modified_modules is None and exclude_modules is None:
            if interactive_modify:
                if isinstance(interactive_modify, bool) and interactive_modify == True:
                    self.modified_modules = interactive(backbone_model)
                else:
                    self.modified_modules = interactive(backbone_model, port=interactive_modify)
                self.common_structure = False
                self.exclude_modules = self.default_exclude_modules
            else:
                self.modified_modules = self.default_modified_modules
                self.common_structure = True
                self.exclude_modules = self.default_exclude_modules
        else:
            if interactive_modify:
                raise ValueError("Using modified_modules (or exclude_modules) and interactive_modify at the same time is not supported")
            if modified_modules is not None:
                self.modified_modules = modified_modules
            else:
                self.modified_modules = self.default_modified_modules
            if exclude_modules is not None:
                self.exclude_modules = exclude_modules
            else:
                self.exclude_modules = self.default_exclude_modules
            self.common_structure = common_structure
        if self.common_structure:
            self.structure_mapping = CommonStructureMap(self.backbone_model)
        else:
            self.structure_mapping = None
        if unfrozen_modules is None:
            unfrozen_modules = self.default_unfrozen_modules
        self.unfrozen_modules = unfrozen_modules
        if self.common_structure and self.structure_mapping is None:
            raise RuntimeError("Using common structure but the structure mapping is None")
        if backend not in self._supported_backends:
            raise RuntimeError("Currently, backend `{}` is not supported for `{}`".format(backend, self.__class__.__name__))
        self.backend = backend
        self.backend_mapping = BackendMapping(backend)

    def forward(self, *args, **kwargs) -> RuntimeError:
        r"""
        .. warning::

            Removed method. As the model is a delta model, which should be attached to a backbone model
            and can't forward any data by itself, please use the backbone model's forward function
            after attaching the delta model to the backbone.
        """
        raise RuntimeError("This is a delta model, which should be attached to a backbone model "
                           "and can't forward any data by itself. Please use the backbone model's forward function "
                           "after attaching the delta model to the backbone.")

    @classmethod
    def from_config(cls, config: Union[BaseDeltaConfig, dict], backbone_model: nn.Module, check_hash=True, **kwargs):
        r"""Initialize a delta model from a config object or a dict containing the configs. To temporarily change
        a value in the config, pass it through kwargs. If the config has a backbone model's hash, which means it is
        a fine-tuned delta model's config, then we will compare the hash in the config with the newly calculated one to ensure
        the fine-tuned delta model was trained on the passed backbone_model. Pass ``check_hash=False`` to disable the
        checking.

        Args:
            config (:obj:`BaseDeltaConfig` or :obj:`dict`): A config object or a dict that contains the necessary values to
                initialize the delta model.
            backbone_model (:obj:`nn.Module`): A pytorch module that will be passed into the delta model as the backbone
                model. Modifications will be made in place in the backbone model.
            check_hash (:obj:`bool`, default to ``True``): Whether to check the hash of the backbone model against the config's
                backbone hash.
            kwargs: Any configurations that are passed to update the config object. #TODO unit test needed.
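
        Example (a sketch; it assumes ``LoraModel`` is available and ``backbone`` is an already loaded ``nn.Module``; the config field name is illustrative):

        .. code-block:: python

            from opendelta import LoraModel              # assumption: exported at the package top level
            config = LoraModel.config_class(lora_r=4)    # `lora_r` is an illustrative config field
            delta_model = LoraModel.from_config(config, backbone_model=backbone)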
"""
supported_keys = get_arg_names(cls.__init__) + get_arg_names(DeltaBase.__init__)
config_dict = config.to_dict()
for key in list(config_dict.keys()):
if key not in supported_keys:
config_dict.pop(key)
return cls(backbone_model, **config_dict)

    def add_all_delta_to_backbone(self,
                                  backbone: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        r"""The main function to add delta modules to the backbone model based on :obj:`modified_modules`.

        Args:
            backbone_model (:obj:`nn.Module`, *required*): The backbone model that the delta modules are built upon. The
                modification to the backbone model is in place.
            modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that are subject to update.
                Leaving this argument as :obj:`None` makes the delta model fall back to its default setting, which adds
                the delta modules to the positions experimented with in the paper. In this setting, the common structure
                mapping is loaded to address the corresponding modules.

        Returns:
            :obj:`nn.Module` The modified backbone model.
        """
        self.plm_total_params = sum(p.numel() for p in backbone.parameters())
        # create a new key list to avoid recursion.
        backbone_key_list = [key for key, _ in backbone.named_modules()]
        for key in backbone_key_list:
            if self.find_key(key, modified_modules):
                self.update_module(backbone, key)
        if self._need_pseudo_data:
            self._pseudo_data_to_instantiate(backbone)
        # mark the parameters that are the delta parameters for easily displaying the delta parameters.
        self.mark_as_delta()
        return backbone

    def _pseudo_data_to_instantiate(self, backbone: Optional[nn.Module]=None):
        if self.structure_mapping is None:
            self._pseudo_data_to_instantiate_module(backbone)
        else:
            for key in self.structure_mapping.matched_pairs:
                if key == "":
                    submodule = backbone
                else:
                    _, _, submodule = self.find_module(backbone, key)
                self._pseudo_data_to_instantiate_module(submodule)

    def mark_as_delta(self, module: nn.Module=None,):
        r"""[NODOC] Mark all of :obj:`module`'s parameters as delta parameters by setting a ``_is_delta`` attribute on each of them.
        Generally, it is used after creating the delta modules. Leaving ``module`` as :obj:`None` will mark all the parameters in the
        delta model as ``_is_delta``.

        Args:
            module (:obj:`nn.Module`): The module to mark as delta.
        """
        if module is None:
            module = self  # all the parameters in the delta model.
        for p in module.parameters():
            setattr(p, "_is_delta", True)

    def update_module(self, module: nn.Module, key: str):
        r"""Update a module specified by :obj:`key`. The method is reimplemented in each specific delta model.
        """
        raise NotImplementedError

    def freeze_module(self,
                      module: Optional[nn.Module] = None,
                      exclude: Optional[List[str]] = None,
                      set_state_dict: Optional[bool] = True,
                      ):
        r"""Freeze the parameters of the plm. Leave the parameters in ``exclude`` untouched.
        The delta modules are filtered with the ``_is_delta`` attribute because they may share parameters with the main
        model (e.g., the bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are frozen.
                If left as :obj:`None`, the function will use ``self.backbone_model`` as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Default to all the delta parameters.
            set_state_dict (:obj:`bool`, *optional*, default to :obj:`True`): Whether to set the backbone model's state
                dict to contain only the parameters that still require grad.
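
        Example (a minimal sketch; ``delta_model`` is an attached :class:`DeltaBase` subclass and the extra unfrozen name is illustrative):

        .. code-block:: python

            # Freeze the whole backbone except the delta parameters and, here, the layer norms.
            delta_model.freeze_module(exclude=["deltas", "layer_norm"], set_state_dict=True)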
"""
if exclude is None:
exclude = self.unfrozen_modules
if module is None:
module = self.backbone_model
self._freeze_module_recursive(module, exclude, "") # modify the active state dict that still need grad
if set_state_dict:
self.set_active_state_dict(module)

    def _freeze_module_recursive(self,
                                 module: Optional[nn.Module] = None,
                                 exclude: Optional[List[str]] = None,
                                 prefix=""):
        r"""[NODOC] Freeze the parameters of the plm. Leave the parameters in ``exclude`` untouched.
        The delta modules are filtered with the ``_is_delta`` attribute because they may share parameters with the main
        model (e.g., the bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are frozen.
                If left as :obj:`None`, the function will use ``self.backbone_model`` as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Default to all the delta parameters.
            prefix (:obj:`str`, *optional*, default to ``""``): A prefix used for the recursive freezing.
                Should not be changed by passing an argument other than ``""``.
        """
        if is_leaf_module(module):
            for n, p in module.named_parameters():
                next_prefix = n if prefix == "" else ".".join([prefix, n])
                if self.find_key(next_prefix, exclude):
                    continue
                if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
                    p.requires_grad = False
            return
        else:
            for n, c in module.named_children():
                next_prefix = n if prefix == "" else ".".join([prefix, n])
                if self.find_key(next_prefix, exclude):  # if found, leave the parameters untouched
                    continue
                else:  # first freeze the non-module params, then go deeper.
                    params = non_module_param(module)
                    for n, p in params:
                        if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
                            p.requires_grad = False
                    self._freeze_module_recursive(c, exclude=exclude, prefix=next_prefix)

    def find_key(self, key: str, target_list: List[str]):
        r"""Check whether any target string is in the key or in the tail of the key, i.e., whether the key ends with
        one of the targets.

        Args:
            key (:obj:`str`): The key (name) of a submodule in an ancestor module.
                E.g., model.encoder.layer.0.attention
            target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key`` with. E.g., ["attention"]

        Returns:
            :obj:`bool` True if the key matches the target list.
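
        Example (illustrative only; the module names follow standard name-based addressing):

        .. code-block:: python

            delta_model.find_key("encoder.block.0.layer.0.SelfAttention.q", ["SelfAttention.q"])    # True
            delta_model.find_key("encoder.block.0.layer.1.DenseReluDense.wi", ["SelfAttention.q"])  # False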
"""
        for x in self.exclude_modules:
            if key.startswith(x):  # starts with an excluded prefix
                return False
        virtual_key, in_virtual_order = None, None
        if self.structure_mapping is not None:
            key, virtual_key, in_virtual_order = self.structure_mapping.transform(key, strict=False)
            # currently in_virtual_order is not in use; it means that if the common structure designates adding an
            # adapter to the FFN, the adapter will be added to all submodules of the FFN.
        if not key:
            return False
        if virtual_key is None:
            return endswith_in(key, target_list)
        else:
            return endswith_in(key, target_list) or endswith_in(virtual_key, target_list)

    def _pseudo_data_to_instantiate_module(self, module: Optional[nn.Module]=None):
        r"""Some delta models require a pseudo input to be passed through the model to understand the dimensionality of each tensor in the computation graph.

        (1) Models in the Huggingface Transformers library usually have the so-called `dummy_inputs`. We will make use of it.
        (2) If the model does not have `dummy_inputs`, we will try to create it and throw a warning.
        (3) If we encounter an error in (2), we will suggest you create it by setting the `dummy_inputs` attribute.

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
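
        Example (a sketch of providing your own ``dummy_inputs``, as the warning below suggests; the token ids are illustrative):

        .. code-block:: python

            import torch
            # a single short sequence of token ids is usually enough for shape inference
            setattr(backbone_model, "dummy_inputs", {"input_ids": torch.tensor([[0, 1, 2]])})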
"""
        if module is None:
            module = self.backbone_model
        device = get_device(module)
        _auto_dummy = False
        try:
            dummy_inputs = module.dummy_inputs
            dummy_inputs = move_dict_to_cuda(dummy_inputs, device)
        except AttributeError:
            logger.warning(f"No `dummy_inputs` attribute in {module.__class__.__name__}, automatically creating `dummy_inputs`. This is very likely to cause an error. To set dummy_inputs for your model, please use: `setattr(backbone_model, 'dummy_inputs', some_dummy_inputs)` before initializing `{self.__class__.__name__}`")
            _auto_dummy = True
        if _auto_dummy:
            _most_simple_input = torch.tensor([[0, 0]]).to(device)
            if "decoder_input_ids" in signature(module.forward).args:
                dummy_inputs = {"input_ids": _most_simple_input, "decoder_input_ids": _most_simple_input}
            else:
                dummy_inputs = {"input_ids": _most_simple_input}
        _auto_dummy_fail = False
        _dummy_exception = None
        try:
            module(**dummy_inputs)
        except Exception as e:
            _auto_dummy_fail = True
            _dummy_exception = e  # keep a reference: `e` goes out of scope after the except block
        if _auto_dummy_fail and _auto_dummy:
            raise AttributeError(f"{_dummy_exception}\n\tThe {self.__class__.__name__} requires a dummy_inputs to be passed through the model to understand the dimensionality of each tensor in the computation graph.\n\tThe {module.__class__.__name__} class has no dummy_inputs, and the automatically created dummy_inputs failed.\n\tRefer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for details.")

    def trainable_parameters_names(self, module: Optional[nn.Module]=None):
        r"""[NODOC] A small sugar function to return all the trainable parameters' names in the (by default, backbone) model.

        Args:
            module (:obj:`nn.Module`): The module whose trainable parameters' names we want to know.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if p.requires_grad]

    def frozen_parameters_names(self, module: Optional[nn.Module]=None):
        r"""[NODOC] A small sugar function to return all the frozen parameters' names in the (by default, backbone) model.

        Args:
            module (:obj:`nn.Module`): The module whose frozen parameters' names we want to know.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if not p.requires_grad]

    def trainable_parameters(self, module: Optional[nn.Module]=None):
        r"""[NODOC] A small sugar function to return all the trainable parameters in the (by default, delta) model.

        Args:
            module (:obj:`nn.Module`): The module whose trainable parameters we want to know.

        Returns:
            :obj:`List[nn.Parameter]`
        """
        if module is None:
            module = self
        return [p for n, p in module.named_parameters() if p.requires_grad]

    def num_trainable_parameters(self, module: Optional[nn.Module]=None):
        r"""[NODOC] A small sugar function to get the number of trainable parameters in the (by default, delta) model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): The module whose number of trainable parameters we want to know.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            if param.requires_grad:
                pnum_tot += param.numel()
        return pnum_tot

    def num_total_parameters(self, module: Optional[nn.Module]=None):
        r"""[NODOC] A small sugar function to get the total number of parameters in the (by default, delta) model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): The module whose total number of parameters we want to know.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            pnum_tot += param.numel()
        return pnum_tot

    def find_module(self, root_module: nn.Module, key: str):
        r"""Find a module using a key and the root module. Return the parent reference, the child name, and the child reference.

        Args:
            root_module (:obj:`root_module`): The root module in which to find the submodule.
            key (:obj:`str`): The key of the submodule, relative to the root module.

        Returns:
            (:obj:`nn.Module`, :obj:`str`, :obj:`nn.Module`):
                * A reference to the parent module of the target module, mainly for substituting the target module.
                * The key of the target module relative to its parent module.
                * The target module.
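
        Example (a sketch; the key is illustrative and depends on the backbone's architecture):

        .. code-block:: python

            parent, name, child = delta_model.find_module(backbone, "encoder.block.0.layer.0.SelfAttention.q")
            # `parent` is the SelfAttention module, `name` == "q", and `child` is the target submodule,
            # so `setattr(parent, name, new_module)` swaps the target in place.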
"""
sub_keys = key.split(".")
parent_module = root_module
for sub_key in sub_keys[:-1]:
parent_module = getattr(parent_module, sub_key)
module = getattr(parent_module, sub_keys[-1])
return parent_module, sub_keys[-1], module

    def _register_delta_infos(self, parent_module, _delta_info):
        r"""Register the delta information.
        Automatically increments the suffix for repeated delta_names.
"""
_delta_infos = getattr(parent_module, "_delta_infos", [])
if len(_delta_infos) > 0: # check if duplicated name
list_of_deltas = [d['delta_name'] for d in _delta_infos]
cur_name = _delta_info['delta_name']
if cur_name in list_of_deltas:
cur_name = cur_name + "_1"
counter = 1
while cur_name in list_of_deltas:
counter += 1
cur_name = cur_name.split("_")[0] + "_"+str(counter)
_delta_info["delta_name"] = cur_name
_delta_infos.append(_delta_info)
setattr(parent_module, "_delta_infos", _delta_infos)

    def replace_module(self,
                       parent_module: nn.Module,
                       child_name: str,
                       child_module: nn.Module,
                       new_module: nn.Module,
                       delta_name: Optional[str] = "delta",
                       ):
        r"""Replace a module's child module with the new_module (a delta module). Used by delta methods based on direct
        replacement, such as :class:`opendelta.delta_modules.lora.LoraModel`.

        Args:
            parent_module (:obj:`nn.Module`): The parent module of the replacement.
            child_name (:obj:`str`): The child module's name, i.e., parent_module.child_name gives us child_module.
            child_module (:obj:`nn.Module`): The original child module.
            new_module (:obj:`nn.Module`): The delta module.
            delta_name (:obj:`str`, *optional*, default to ``delta``): The name of the delta module, used for recording.
                parent_module.delta_name WILL NOT give you the delta module.
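
        Example (a sketch of how a derived delta model typically calls this from :meth:`update_module`; ``MyDeltaLinear`` is a hypothetical replacement module):

        .. code-block:: python

            parent, name, child = self.find_module(backbone, key)
            self.replace_module(parent, name, child, MyDeltaLinear(child), delta_name="delta")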
"""
        self.delta_modules.append(new_module)
        setattr(parent_module, child_name, new_module)
        # register delta info
        _delta_info = {"method": "replace",
                       "delta_module": new_module,
                       "child_name": child_name,
                       "org_module": child_module,
                       "delta_name": delta_name,
                       "delta_belong": self,
                       "state": "on"}
        self._register_delta_infos(parent_module=parent_module,
                                   _delta_info=_delta_info,
                                   )

    def modify_module(self, module: nn.Module):
        r"""Modify the inside parameters of a module. This method will be reimplemented in different
        derived classes if needed.
        """
        raise NotImplementedError

    def insert_module(self, module, method='sequential', delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (one that previously did not exist in the code base) before/after a module. Specifically, it modifies the forward
        function of the original module to first pass the arguments into the new module's forward function and then pass
        them into the original one. The new module can also be inserted after the original module with a similar mechanism.
        When implementing the new module, researchers should be aware of the arguments of the original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module into which to insert a delta module.
            method: (:obj:`str`, *optional*, default to ``'sequential'``): How the delta module is combined with the
                original module's forward pass (e.g., ``"sequential"`` or ``"parallel"``); must be a key of ``caller_map``.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in attach() to reattach a delta module to the backbone. The info of the
                original delta is passed through ``_delta_info``.
        """
        if strict:
            if hasattr(module.forward, "__wrapped__"):
                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")

        # record info for plug and unplug and nested wrap
        if _delta_info is None:
            if delta_module is None:
                raise RuntimeError("delta module can't be None, to ensure successful replication of the parent module.")
            _delta_info = {"method": method,
                           "delta_module": delta_module,
                           "delta_name": delta_name,
                           "delta_belong": self,
                           "state": "on"}
            self._register_delta_infos(parent_module=module,
                                       _delta_info=_delta_info)
        else:
            delta_module = _delta_info["delta_module"]
            delta_name = _delta_info["delta_name"]

        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])

        if _delta_info["method"] in caller_map.keys():
            caller = caller_map[_delta_info["method"]]
            new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True)  # decorator.decorate helps preserve the function's metadata (signature, etc.).
            module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) registers a function as an object's method
            # for DataParallel's copy behavior. Experimental:
            # may have bugs when module.forward is nestedly wrapped.
            module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
        else:
            raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")

    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (one that previously did not exist in the code base) before/after a module. Specifically, it modifies the forward
        function of the original module to first pass the arguments into the new module's forward function and then pass
        them into the original one. The new module can also be inserted after the original module with a similar mechanism.
        When implementing the new module, researchers should be aware of the arguments of the original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module into which to insert a delta module.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in attach() to reattach a delta module to the backbone. The info of the
                original delta is passed through ``_delta_info``.
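
        Example (a sketch of how a derived delta model typically calls this from :meth:`update_module`; ``adapter`` stands for an already constructed delta module):

        .. code-block:: python

            _, _, ffn = self.find_module(backbone, key)      # locate the target submodule
            self.insert_sequential_module(ffn, delta_module=adapter, delta_name="adapter")
            # ffn.forward is now wrapped so the adapter runs together with the original forward.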
"""
self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)

    def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (one that previously did not exist in the code base) in parallel with a module. Specifically, it modifies the forward
        function of the original module to first pass the arguments into the delta model's forward function and set
        aside the calculation result. The result is then combined with the calculation result output by the backbone module.
        When implementing the new module, researchers should be aware of the arguments and keywords of the original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module into which to insert a delta module.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in attach() to reattach a delta module to the backbone. The info of the
                original delta is passed through ``_delta_info``.
        """
        self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)

    def set_active_state_dict(self, module: nn.Module):
        r"""Modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.

        Args:
            module (:obj:`nn.Module`): The module to modify. The modification is in-place.
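
        Example (a sketch; after this call, saving the backbone's state dict only stores the tunable part):

        .. code-block:: python

            delta_model.freeze_module(backbone, set_state_dict=True)   # calls set_active_state_dict internally
            torch.save(backbone.state_dict(), "delta.ckpt")            # now stores only the tunable weights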
"""
        def _caller(_org_func, includes, *args, **kwargs):
            state_dict = _org_func(*args, **kwargs)
            keys = list(state_dict.keys())
            for n in keys:
                if n not in includes:
                    state_dict.pop(n)
            return state_dict
        includes = self.trainable_parameters_names(module)  # using excludes would cause trouble when the model has shared weights
        if hasattr(module.state_dict, "__wrapped__"):
            raise RuntimeWarning("The state_dict function might have been wrapped by a decorator, is it intended? Did you freeze the parameters twice?")
        module.state_dict = decorate(module.state_dict, _caller, extras=(includes,), kwsyntax=True)  # decorator.decorate helps preserve the function's metadata (signature, etc.).

    def _load_state_dict_into_backbone(self, backbone_model: nn.Module = None, state_dict: dict = {}):
        r"""[NODOC]
        """
        if backbone_model is None:
            backbone_model = self.backbone_model
        backbone_model.load_state_dict(state_dict, strict=False)

    def create_config_from_model(self, ):
        r"""[NODOC] If the delta model was built by directly passing arguments instead of a config object,
        create the config of the delta model for saving the delta model.
        """
        # common_attributes
        config = self.config_class()
        config_keys = signature(config.__init__)[0] + signature(super(self.config_class, config).__init__)[0]
        for key in config_keys:
            val = getattr(self, key) if hasattr(self, key) else None
            setattr(config, key, val)
        config.delta_type = self.delta_type
        self.config = config

    def log(self, module=None, delta_ratio=True, trainable_ratio=True, visualization=True, cuda_memory=True):
        r"""Log and visualize the result of applying delta.
        Possible options are ``trainable_ratio``, ``visualization``, ``delta_ratio``, and ``cuda_memory``.

        Args:
            module (:obj:`nn.Module`, *optional*): The module to inspect; defaults to the backbone model.
            delta_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of parameters in the delta modules.
            trainable_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of trainable parameters.
            visualization (:obj:`bool`, *optional*): Whether to visualize the parameter information of the modified backbone.
            cuda_memory (:obj:`bool`, *optional*): Whether to report the static and maximum CUDA memory usage.
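
        Example (a sketch):

        .. code-block:: python

            delta_model.log(visualization=False)   # print trainable/delta ratios without the structure graph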
"""
        if module is None:
            module = self.backbone_model
        if visualization:
            from opendelta import Visualization
            Visualization(module).structure_graph()

        self.stat = inspect_module_statistics(module, verbose=False)
        if trainable_ratio:
            logger.info("Trainable Ratio: {}/{}={:.6f}%".format(self.stat['trainable_parameters'], self.stat['total_parameters'], self.stat['trainable_ratio']*100))
        if delta_ratio:
            logger.info("Delta Parameter Ratio: {}/{}={:.6f}%".format(self.stat['delta_parameters'], self.stat['total_parameters'], self.stat['delta_ratio']*100))
        if cuda_memory:
            logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))

    # Two functions for plugging and removing the delta model.
    def attach(self, module: Optional[nn.Module]=None, reset_state_dict=True):
        r"""Reattach the delta modules to the backbone. Note that this method can not be used to create new delta modules.
        Instead, a :meth:`DeltaBase.detach` should precede this method.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                reattach the deltas to.
"""
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "on":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['delta_module'])
                    elif _delta_info['method'] in ["sequential", "before", "after", "insert_sequential"]:
                        # insert_module() records the method as "sequential"/"before"/"after";
                        # "insert_sequential" is kept for compatibility with older records.
                        self.insert_sequential_module(module=submodule,
                                                      _delta_info=_delta_info)
                    elif _delta_info['method'] in ["parallel", "insert_parallel"]:
                        self.insert_parallel_module(module=submodule,
                                                    _delta_info=_delta_info)
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "on"
        if reset_state_dict:
            self.set_active_state_dict(module)

    def detach(self, module: Optional[nn.Module]=None, reset_state_dict=True):
        r"""Detach the delta modules from the backbone. The delta modules are not deleted, but temporarily turned off.
        Use :meth:`DeltaBase.attach` to reattach the delta model to the backbone.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                detach the deltas from.
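
        Example (a sketch; toggles the delta off and back on):

        .. code-block:: python

            delta_model.detach()   # the backbone now behaves like the vanilla pretrained model
            delta_model.attach()   # the delta modules take effect again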
"""
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "off":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['org_module'])
                    elif _delta_info['method'] in ["sequential", "before", "after", "parallel"]:
                        if hasattr(submodule.forward, "__wrapped__"):
                            submodule.forward = submodule.forward.__wrapped__
                            delattr(submodule, _delta_info["delta_name"])
                        else:
                            raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "off"
        if reset_state_dict:
            try:
                module.state_dict = module.state_dict.__wrapped__
            except AttributeError:
                pass