# OpenDeltaMirror/opendelta/basemodel.py
from collections import OrderedDict
import os
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.model_md5 import gen_model_hash
from opendelta.utils.signature import get_arg_names, signature
from typing import Optional, Union
from opendelta.utils.cuda import get_device
from opendelta.utils.name_based_addressing import *
import torch.nn as nn
import torch
from functools import wraps
# from decorator import decorate
from opendelta.utils.decorate import decorate
from opendelta.utils.structure_mapping import transform
from transformers.file_utils import PushToHubMixin
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from opendelta import SaveLoadMixin
from opendelta import logging
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.cuda import move_dict_to_cuda
import sys

logger = logging.get_logger(__name__)


def is_leaf_module(module):
    r"""Whether the module is a leaf module, i.e., one without child modules.
    """
    return len([n for n, _ in module.named_children()]) == 0


def non_module_param(module: nn.Module):
    r"""Collect the (name, parameter) pairs that are registered directly on ``module``
    itself rather than on any of its child modules (e.g., a bias term held by the module).
    """
    module_names = [n for n, _ in module.named_modules()]
    ret = []
    for n, p in module.named_parameters():
        if not is_child_key(n, module_names):
            ret.append((n, p))
    return ret


class DeltaBase(nn.Module, SaveLoadMixin):
r"""This is the base class for all delta models. It provides four simple but effective functionalities
2022-02-14 21:19:03 +08:00
for building the delta model:
2022-04-14 11:22:41 +08:00
#. addressing a module inside the backbone model using a minimal description key.
#. provide the interface for modifying and inserting model which keeps the docs/IO the same as the module
2022-02-14 21:19:03 +08:00
before modification.
2022-04-14 11:22:41 +08:00
#. pass a pseudo input to determine the inter dimension of the delta models.
#. freeze a part of model parameters according to key.
It also provides unified interface for model loading and saving.
2022-02-14 21:19:03 +08:00
Class attributes (overridden by derived classes):
- delta_type (:obj:`str`): the name of the delta modules, used to create the correct :class:`opendelta.AutoDeltaModel`.
2022-04-14 11:22:41 +08:00
- config_class (:class:`BaseDeltaConfig`): The corresponding config model
2022-02-14 21:19:03 +08:00
Args:
2022-04-14 11:22:41 +08:00
backbone_model (:obj:`nn.Module`, *required*): backbone model that the delta models are build opon. The modification to the
2022-02-14 21:19:03 +08:00
backbone model are in place.
2022-04-14 11:22:41 +08:00
modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules are subjected to update.
2022-02-14 21:19:03 +08:00
.. note::
leave this argument :obj:`None` will make the delta model return to the default setting, which add the delta
2022-04-14 11:22:41 +08:00
models to the position experimented the paper. In this setting, the common structure mapping is loaded to
2022-02-14 21:19:03 +08:00
addressing the corresponding modules.
2022-04-14 11:22:41 +08:00
exclude_modules (:obj:`str`, *optional*, default to :obj:`None`): The modules starts with these strings will be excluded in modification.
Note that currently only plain text (no regular expression) is supported.
unfrozen_modules (:obj:`str`, *optional*, default to :obj:`None`): The modules that are **not** frozen when freezing the main part of the model.
2022-02-14 21:19:03 +08:00
registraction_name (:obj:`str`, *optional*, default to ``"deltas"``): The root name of the delta models when
2022-04-14 11:22:41 +08:00
attached to the backbone model.
2022-02-14 21:19:03 +08:00
common_structure (:obj:`bool`, *optional*, default to :obj:`None`): Whether use the common structure mapping to specify the
modified_modules. i.e., if common_structure=True, then we use a common ["attn"] for attention module in different models.
2022-04-14 11:22:41 +08:00
We DO NOT recommend manually set ``common_structure`` to ``true`` by yourself unless you are using delta
among multiple backbones and don't want to modify the code.
2022-02-14 21:19:03 +08:00
interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`None`): Whether to use interactive modification.
By setting to :obj:`int` can specify the port of web server.
"""
    delta_type = ""
    default_modified_modules = []
    default_exclude_modules = ["lm_head"]
    config_class = BaseDeltaConfig
    default_unfrozen_modules = ["deltas"]
    _need_pseudo_data = True

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 common_structure: Optional[bool] = False,
                 ):
        nn.Module.__init__(self)
        # Register the backbone model after init via self.__dict__ to avoid adding
        # backbone_model to the submodules of the delta model.
        self.__dict__["backbone_model"] = backbone_model
        if modified_modules is None and exclude_modules is None:
            if interactive_modify:
                if isinstance(interactive_modify, bool) and interactive_modify:
                    self.modified_modules = interactive(backbone_model)
                else:
                    self.modified_modules = interactive(backbone_model, port=interactive_modify)
                self.common_structure = False
                self.exclude_modules = self.default_exclude_modules
            else:
                self.modified_modules = self.default_modified_modules
                self.common_structure = True
                self.exclude_modules = self.default_exclude_modules
        else:
            if interactive_modify:
                raise ValueError("Using modified_modules (or exclude_modules) and interactive_modify at the same time is not supported.")
            if modified_modules is not None:
                self.modified_modules = modified_modules
            else:
                self.modified_modules = self.default_modified_modules
            if exclude_modules is not None:
                self.exclude_modules = exclude_modules
            else:
                self.exclude_modules = self.default_exclude_modules
            self.common_structure = common_structure
        if self.common_structure:
            self.structure_mapping = CommonStructureMap(self.backbone_model)
        else:
            self.structure_mapping = None
        if unfrozen_modules is None:
            self.unfrozen_modules = self.default_unfrozen_modules
        else:  # fix: previously a user-supplied unfrozen_modules was silently ignored
            self.unfrozen_modules = unfrozen_modules
        if self.common_structure and self.structure_mapping is None:
            raise RuntimeError("Using common structure but the structure mapping is None")

    def forward(self, *args, **kwargs):
        r"""
        .. warning::

            Removed method. As the model is a delta model, which should be attached to a backbone model
            and can't forward any data by itself, please use the backbone model's forward function
            after attaching the delta model to the backbone.
        """
        raise RuntimeError("This is a delta model, which should be attached to a backbone model "
                           "and can't forward any data by itself. Please use the backbone model's "
                           "forward function after attaching the delta model to the backbone.")

    @classmethod
    def from_config(cls, config: Union[BaseDeltaConfig, dict], backbone_model: nn.Module, check_hash=True, **kwargs):
        r"""Initialize a delta model from a config object or a dict containing the configs. To temporarily change
        a value in the config, pass it through kwargs. If the config has a backbone model's hash, which means it is
        a fine-tuned delta model's config, we compare the hash in the config with the newly calculated one to ensure
        that the fine-tuned delta model was trained on the passed backbone_model. Pass ``check_hash=False`` to
        disable the checking.

        Args:
            config (:obj:`BaseDeltaConfig` or :obj:`dict`): A config object or a dict that contains the necessary
                values to initialize the delta model.
            backbone_model (:obj:`nn.Module`): A pytorch module that will be passed into the delta model as the
                backbone model. Modifications will be made in place in the backbone model.
            check_hash (:obj:`bool`, default to ``True``): Whether to check the hash of the backbone model against
                the config's backbone hash.
            kwargs: Any configurations that are passed to update the config object. #TODO unit test needed.
        """
        supported_keys = get_arg_names(cls.__init__) + get_arg_names(DeltaBase.__init__)
        config_dict = dict(config) if isinstance(config, dict) else config.to_dict()  # accept both dicts and config objects
        for key in list(config_dict.keys()):
            if key not in supported_keys:
                config_dict.pop(key)
        return cls(backbone_model, **config_dict)
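
    # Usage sketch (hedged; ``AutoDeltaConfig``/``AutoDeltaModel`` usage and the
    # checkpoint name are assumptions drawn from the wider OpenDelta ecosystem, not
    # from this file):
    #
    #     from transformers import AutoModelForSequenceClassification
    #     from opendelta import AutoDeltaConfig, AutoDeltaModel
    #     backbone = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    #     config = AutoDeltaConfig.from_dict({"delta_type": "lora"})
    #     delta = AutoDeltaModel.from_config(config, backbone_model=backbone)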

    def add_all_delta_to_backbone(self,
                                  backbone: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        r"""The main function to add delta models to the backbone model based on the :obj:`modified_modules`.

        Args:
            backbone (:obj:`nn.Module`, *required*): The backbone model that the delta modules are built upon. The
                modification to the backbone model is made in place.
            modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules to be updated.
                Leaving this argument as :obj:`None` makes the delta model fall back to its default setting, which
                adds the delta modules to the positions experimented with in the paper. In this setting, the common
                structure mapping is loaded to address the corresponding modules.

        Returns:
            :obj:`nn.Module`: The modified backbone model.
        """
        self.plm_total_params = sum(p.numel() for p in backbone.parameters())
        # create a new key list to avoid recursion while modifying the module tree
        backbone_key_list = [key for key, _ in backbone.named_modules()]
        for key in backbone_key_list:
            if self.find_key(key, modified_modules):
                self.update_module(backbone, key)
        if self._need_pseudo_data:
            self._pseudo_data_to_instantiate(backbone)
        # mark the parameters that are delta parameters, for easy display of the delta parameters
        self.mark_as_delta()
        return backbone
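
    # Illustrative sketch: a concrete DeltaBase subclass typically calls this from
    # its own __init__ after resolving the target keys, e.g.
    #
    #     self.add_all_delta_to_backbone(self.backbone_model, self.modified_modules)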

    def _pseudo_data_to_instantiate(self, backbone: Optional[nn.Module] = None):
        if self.structure_mapping is None:
            self._pseudo_data_to_instantiate_module(backbone)
        else:
            for key in self.structure_mapping.matched_pairs:
                if key == "":
                    submodule = backbone
                else:
                    _, _, submodule = self.find_module(backbone, key)
                self._pseudo_data_to_instantiate_module(submodule)

    def mark_as_delta(self, module: nn.Module = None):
        r"""[NODOC] Mark all of :obj:`module`'s parameters as delta parameters by setting a ``_is_delta``
        attribute on each of them. Generally, it is used after creating the delta modules. Leaving ``module``
        as :obj:`None` marks all the parameters in the delta model as ``_is_delta``.

        Args:
            module (:obj:`nn.Module`): The module to mark as delta.
        """
        if module is None:
            module = self  # all the parameters in the delta model
        for p in module.parameters():
            setattr(p, "_is_delta", True)

    def update_module(self, module: nn.Module, key: str):
        r"""Update a module specified by :obj:`key`. The method is reimplemented in each specific delta model.
        """
        raise NotImplementedError

    def freeze_module(self,
                      module: Optional[nn.Module] = None,
                      exclude: Optional[List[str]] = None,
                      set_state_dict: Optional[bool] = True,
                      ):
        r"""Freeze the parameters of the plm. Leave the parameters in ``exclude`` untouched.
        The deltas module is filtered with the ``_is_delta`` attribute because it may share parameters with the
        main model (e.g., a bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are
                frozen. If left as :obj:`None`, the function uses ``self.backbone_model`` as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Default to all the delta parameters.
            set_state_dict (:obj:`bool`, *optional*, default to :obj:`True`): Whether to set the backbone model's
                state dict to contain only the parameters that still need grad.
        """
        if exclude is None:
            exclude = self.unfrozen_modules
        if module is None:
            module = self.backbone_model
        self._freeze_module_recursive(module, exclude, "")  # modify the active state dict that still needs grad
        if set_state_dict:
            self.set_active_state_dict(module)
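
    # Freezing sketch: after attaching the deltas, freeze everything except the delta
    # parameters (the extra unfrozen key is an arbitrary example, not a requirement):
    #
    #     delta.freeze_module(exclude=["deltas", "layernorm_embedding"], set_state_dict=True)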

    def _freeze_module_recursive(self,
                                 module: Optional[nn.Module] = None,
                                 exclude: Optional[List[str]] = None,
                                 prefix=""):
        r"""[NODOC] Freeze the parameters of the plm. Leave the parameters in ``exclude`` untouched.
        The deltas module is filtered with the ``_is_delta`` attribute because it may share parameters with the
        main model (e.g., a bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are
                frozen. If left as :obj:`None`, the function uses ``self.backbone_model`` as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Default to all the delta parameters.
            prefix (:obj:`str`, *optional*, default to ``""``): A parameter used for the recursive freezing.
                Should not be changed by passing an argument other than ``""``.
        """
        if is_leaf_module(module):
            for n, p in module.named_parameters():
                if self.find_key(".".join([prefix, n]), exclude):
                    continue
                if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
                    p.requires_grad = False
            return
        else:
            for n, c in module.named_children():
                if self.find_key(".".join([prefix, n]), exclude):  # if found, leave the parameters untouched
                    continue
                else:  # first freeze the non-module params, then go deeper
                    params = non_module_param(module)
                    for pn, p in params:  # renamed from ``n`` to avoid shadowing the child name used below
                        if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
                            p.requires_grad = False
                    self._freeze_module_recursive(c, exclude=exclude, prefix=".".join([prefix, n]))

    def find_key(self, key: str, target_list: List[str]):
        r"""Check whether the key matches any target string, i.e., whether any target is the tail of the key.

        Args:
            key (:obj:`str`): The key (name) of a submodule in an ancestor module,
                e.g., ``model.encoder.layer.0.attention``.
            target_list (:obj:`List[Union[str, re.Pattern]]`): The target list that we try to match ``key`` with,
                e.g., ``["attention"]``.

        Returns:
            :obj:`bool`: True if the key matches the target list.
        """
        for x in self.exclude_modules:
            if key.startswith(x):  # the key starts with an excluded prefix
                return False
        virtual_key, in_virtual_order = None, None  # fix: previously undefined when no structure mapping was used
        if self.structure_mapping is not None:
            key, virtual_key, in_virtual_order = self.structure_mapping.transform(key, strict=False)
            # in_virtual_order is currently unused; it means that if the common structure designates adding an
            # adapter to the FFN, the adapter will be added to all submodules of the FFN.
        if not key:
            return False
        if virtual_key is None:
            return endswith_in(key, target_list)
        else:
            return endswith_in(key, target_list) or endswith_in(virtual_key, target_list)
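
    # Matching sketch (keys are made-up T5-style names): with the default
    # exclude_modules = ["lm_head"] and no structure mapping,
    #
    #     delta.find_key("encoder.block.0.layer.0.SelfAttention.q", ["q", "v"])  # True: tail matches "q"
    #     delta.find_key("lm_head", ["q", "v"])                                  # False: excluded prefix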

    def _pseudo_data_to_instantiate_module(self, module: Optional[nn.Module] = None):
        r"""Some delta models require pseudo data to be passed through the model to determine the dimensionality
        of each tensor in the computation graph.

        (1) Models in the Huggingface Transformers library usually have so-called `dummy_inputs`; we make use of them.
        (2) If the model does not have `dummy_inputs`, we try to create them and emit a warning.
        (3) If we encounter an error in (2), we suggest that you create them yourself by setting the
            `dummy_inputs` attribute.

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
        """
        if module is None:
            module = self.backbone_model
        device = get_device(module)
        _auto_dummy = False
        try:
            dummy_inputs = module.dummy_inputs
            dummy_inputs = move_dict_to_cuda(dummy_inputs, device)
        except AttributeError:
            logger.warning(f"No `dummy_inputs` attribute in {module.__class__.__name__}, automatically creating `dummy_inputs`. This is likely to raise an error. To set dummy_inputs for your model, please use `setattr(backbone_model, 'dummy_inputs', some_dummy_inputs)` before initializing `{self.__class__.__name__}`.")
            _auto_dummy = True
        if _auto_dummy:
            _most_simple_input = torch.tensor([[0, 0]]).to(device)
            if "decoder_input_ids" in signature(module.forward).args:
                dummy_inputs = {"input_ids": _most_simple_input, "decoder_input_ids": _most_simple_input}
            else:
                dummy_inputs = {"input_ids": _most_simple_input}
        _auto_dummy_fail = False
        try:
            module(**dummy_inputs)
        except Exception:
            _auto_dummy_fail = True
        if _auto_dummy_fail:
            raise AttributeError(f"\n{self.__class__.__name__} requires pseudo data to be passed through the model to determine the dimensionality of each tensor in the computation graph. \nThe automatically created dummy inputs failed.\n`dummy_inputs` can be any data that makes `backbone_model.forward(**dummy_inputs)` succeed; only the form and shape of the `dummy_inputs` matter.\n\tTo set dummy_inputs for your model, please use `setattr(backbone_model, 'dummy_inputs', some_dummy_inputs)` before initializing `{self.__class__.__name__}`.")
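
    # Dummy-input sketch (tensor contents are placeholders; only form/shape matter),
    # mirroring the advice in the error message above:
    #
    #     setattr(backbone_model, "dummy_inputs", {"input_ids": torch.tensor([[0, 1, 2]])})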

    def trainable_parameters_names(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return the names of all the trainable parameters in the
        (by default, backbone) model.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the trainable parameters' names.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if p.requires_grad]

    def frozen_parameters_names(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return the names of all the frozen parameters in the
        (by default, backbone) model.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the frozen parameters' names.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if not p.requires_grad]

    def trainable_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return all the trainable parameters in the (by default, delta) model.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the trainable parameters.

        Returns:
            :obj:`List[nn.Parameter]`
        """
        if module is None:
            module = self
        return [p for n, p in module.named_parameters() if p.requires_grad]

    def num_trainable_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the number of trainable parameters in the model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the number of trainable parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            if param.requires_grad:
                pnum_tot += param.numel()
        return pnum_tot

    def num_total_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the total number of parameters in the model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the total number of parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            pnum_tot += param.numel()
        return pnum_tot
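
    # Ratio sketch: the trainable rate reported by ``log`` is simply
    #
    #     rate = delta.num_trainable_parameters(backbone) / delta.num_total_parameters(backbone)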

    def find_module(self, root_module: nn.Module, key: str):
        r"""Find a module using a key and the root module. Return the parent reference, the child name, and the
        child reference.

        Args:
            root_module (:obj:`nn.Module`): The root module to find the submodule in.
            key (:obj:`str`): The key of the submodule relative to the root module.

        Returns:
            (:obj:`nn.Module`, :obj:`str`, :obj:`nn.Module`):
                * A reference to the parent module of the target module, mainly for substituting the target module.
                * The key of the target module relative to its parent module.
                * The target module.
        """
        sub_keys = key.split(".")
        parent_module = root_module
        for sub_key in sub_keys[:-1]:
            parent_module = getattr(parent_module, sub_key)
        module = getattr(parent_module, sub_keys[-1])
        return parent_module, sub_keys[-1], module
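
    # Addressing sketch (the key is a made-up example path):
    #
    #     parent, name, child = delta.find_module(backbone, "encoder.layer.0.attention")
    #     # parent is backbone.encoder.layer.0, name == "attention", child is the attention module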

    def _register_delta_infos(self, parent_module, _delta_info):
        r"""Register the delta information.
        Automatically increments a numeric suffix for repeated delta names.
        """
        _delta_infos = getattr(parent_module, "_delta_infos", [])
        if len(_delta_infos) > 0:  # check for duplicated names
            list_of_deltas = [d['delta_name'] for d in _delta_infos]
            base_name = _delta_info['delta_name']
            if base_name in list_of_deltas:
                counter = 1
                cur_name = base_name + "_" + str(counter)
                while cur_name in list_of_deltas:  # keep the full base name intact (even if it contains "_")
                    counter += 1
                    cur_name = base_name + "_" + str(counter)
                _delta_info["delta_name"] = cur_name
        _delta_infos.append(_delta_info)
        setattr(parent_module, "_delta_infos", _delta_infos)

    def replace_module(self,
                       parent_module: nn.Module,
                       child_name: str,
                       child_module: nn.Module,
                       new_module: nn.Module,
                       delta_name: Optional[str] = "delta",
                       ):
        r"""Replace a module's child module with the new_module (a delta module). Used by delta methods based on
        direct replacement, such as :class:`opendelta.delta_modules.lora.LoraModel`.

        Args:
            parent_module (:obj:`nn.Module`): The parent module of the replacement.
            child_name (:obj:`str`): The child module's name, i.e., ``parent_module.child_name`` gives us the
                child module.
            child_module (:obj:`nn.Module`): The original child module.
            new_module (:obj:`nn.Module`): The delta module.
            delta_name (:obj:`str`, *optional*, default to ``delta``): The name of the delta module, used for
                recording. ``parent_module.delta_name`` will NOT give you the delta module.
        """
        self.delta_modules.append(new_module)
        setattr(parent_module, child_name, new_module)
        # register delta info
        _delta_info = {"method": "replace",
                       "delta_module": new_module,
                       "child_name": child_name,
                       "org_module": child_module,
                       "delta_name": delta_name,
                       "delta_belong": self,
                       "state": "on"}
        self._register_delta_infos(parent_module=parent_module,
                                   _delta_info=_delta_info,
                                   )
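
    # Replacement sketch (``LowRankLinear`` is a hypothetical delta module, not
    # defined in this file): a LoRA-style delta swaps a linear layer for a wrapped
    # one and records the original so it can be detached later.
    #
    #     parent, name, child = self.find_module(backbone, key)
    #     self.replace_module(parent, name, child, LowRankLinear(child), delta_name="lora")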

    def modify_module(self, module: nn.Module):
        r"""Modify the inner parameters of a module. This method is reimplemented in the derived classes
        that need it.
        """
        raise NotImplementedError

    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (one that previously did not exist in the code base) before/after a module.
        Specifically, it modifies the forward function of the original module to first pass the arguments into
        the new module's forward function and then pass them into the original one. The new module can also be
        inserted after the original module with a similar mechanism.

        When implementing the new module, researchers should be aware of the components of the arguments of the
        original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module to insert a delta module into.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            delta_name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in :meth:`attach` to reattach a delta module to the
                backbone. The info of the original delta is passed through ``_delta_info``.
        """
        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
            args = args[1:]  # the first argument here is ``self``
            delta_module = getattr(org_module, delta_name)
            if hasattr(delta_module, "pre_forward"):
                args, kwargs = delta_module.pre_forward(*args, **kwargs)
            ret = _org_func(*args, **kwargs)
            if hasattr(delta_module, "post_forward"):
                ret = delta_module.post_forward(ret)
            return ret

        if strict:
            if hasattr(module.forward, "__wrapped__"):
                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")

        # record info for plug/unplug and nested wrapping
        if _delta_info is None:
            if delta_module is None:
                raise RuntimeError("delta_module can't be None; it is needed to ensure a successful replicate of the parent module.")
            _delta_info = {"method": "insert_sequential",
                           "delta_module": delta_module,
                           "delta_name": delta_name,
                           "delta_belong": self,
                           "state": "on"}
            self._register_delta_infos(parent_module=module,
                                       _delta_info=_delta_info)
        else:
            delta_module = _delta_info["delta_module"]
            delta_name = _delta_info["delta_name"]

        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True)  # decorate helps preserve the function's metadata (signature, etc.)
        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) registers a function as an object's method
        # for DataParallel's copy behavior. Experimental:
        # may have bugs when module.forward is wrapped in a nested manner.
        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
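
    # Contract sketch for sequentially inserted delta modules (``TinyAdapter`` is a
    # hypothetical illustration, not defined in this file): a delta module may
    # implement ``pre_forward`` to rewrite the inputs and/or ``post_forward`` to
    # rewrite the outputs of the wrapped submodule.
    #
    #     class TinyAdapter(nn.Module):
    #         def __init__(self, dim, bottleneck=16):
    #             super().__init__()
    #             self.down = nn.Linear(dim, bottleneck)
    #             self.up = nn.Linear(bottleneck, dim)
    #
    #         def post_forward(self, output):
    #             hidden = output[0] if isinstance(output, tuple) else output
    #             hidden = hidden + self.up(torch.relu(self.down(hidden)))
    #             return (hidden,) + output[1:] if isinstance(output, tuple) else hidden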

    def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (one that previously did not exist in the code base) in parallel to a module.
        Specifically, it modifies the forward function of the original module to also pass the arguments into
        the delta module's forward function and set aside its result, which is then combined with the result
        output by the backbone module.

        When implementing the new module, researchers should be aware of the arguments and keywords of the
        original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module to insert a delta module into.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            delta_name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in :meth:`attach` to reattach a delta module to the
                backbone. The info of the original delta is passed through ``_delta_info``.
        """
        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
            args = args[1:]  # the first argument here is ``self``
            delta_module = getattr(org_module, delta_name)
            ret_1 = _org_func(*args, **kwargs)
            ret_2 = delta_module.forward(*args, **kwargs)
            return ret_1 + ret_2

        if strict:
            if hasattr(module.forward, "__wrapped__"):
                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")

        # record info for plug/unplug and nested wrapping
        if _delta_info is None:
            if delta_module is None:
                raise RuntimeError("delta_module can't be None; it is needed to ensure a successful replicate of the parent module.")
            _delta_info = {"method": "insert_parallel",
                           "delta_module": delta_module,
                           "delta_name": delta_name,
                           "delta_belong": self,
                           "state": "on"}
            self._register_delta_infos(parent_module=module,
                                       _delta_info=_delta_info)
        else:
            delta_module = _delta_info["delta_module"]
            delta_name = _delta_info["delta_name"]

        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True)  # decorate helps preserve the function's metadata (signature, etc.)
        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) registers a function as an object's method
        # for DataParallel's copy behavior. Experimental:
        # may have bugs when module.forward is wrapped in a nested manner.
        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
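
    # Combination sketch: for a parallel insertion, the wrapped forward computes
    #
    #     output = original_module(*args, **kwargs) + delta_module(*args, **kwargs)
    #
    # so the delta module's forward must return something addable to the backbone
    # submodule's output (e.g., a tensor of the same shape).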

    def set_active_state_dict(self, module: nn.Module):
        r"""Modify the state_dict function of the model (by default, the backbone model) to return only the
        tunable part.

        Args:
            module (:obj:`nn.Module`): The module to modify. The modification is in place.
        """
        def _caller(_org_func, includes, *args, **kwargs):
            state_dict = _org_func(*args, **kwargs)
            keys = list(state_dict.keys())
            for n in keys:
                if n not in includes:
                    state_dict.pop(n)
            return state_dict
        includes = self.trainable_parameters_names(module)  # using excludes instead would cause trouble when the model has shared weights
        if hasattr(module.state_dict, "__wrapped__"):
            raise RuntimeWarning("The state_dict function might have been wrapped by a decorator, is it intended? Did you freeze the parameters twice?")
        module.state_dict = decorate(module.state_dict, _caller, extras=(includes,), kwsyntax=True)  # decorate helps preserve the function's metadata (signature, etc.)
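
    # Effect sketch: after set_active_state_dict, saving the backbone stores only the
    # tunable tensors (the file name is a placeholder):
    #
    #     torch.save(backbone.state_dict(), "delta_ckpt.pt")  # contains only trainable entries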

    def _load_state_dict_into_backbone(self, backbone_model: nn.Module = None, state_dict: dict = {}):
        r"""[NODOC]
        """
        if backbone_model is None:
            backbone_model = self.backbone_model
        backbone_model.load_state_dict(state_dict, strict=False)  # fix: previously ignored the backbone_model argument

    def create_config_from_model(self):
        r"""[NODOC] If the delta model was built by directly passing arguments instead of passing a config object,
        create the config of the delta model for saving it.
        """
        # common attributes
        config = self.config_class()
        config_keys = signature(config.__init__)[0] + signature(super(self.config_class, config).__init__)[0]
        for key in config_keys:
            val = getattr(self, key) if hasattr(self, key) else None
            setattr(config, key, val)
        config.delta_type = self.delta_type
        self.config = config

    def log(self, module=None, delta_ratio=True, trainable_ratio=True, visualization=True, cuda_memory=True):
        r"""Log and visualize the result of applying delta.
        Possible options are ``trainable_ratio``, ``visualization``, ``delta_ratio``, and ``cuda_memory``.

        Args:
            module (:obj:`nn.Module`, *optional*): The module to log. Defaults to the backbone model.
            delta_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of parameters in the delta modules.
            trainable_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of trainable parameters.
            visualization (:obj:`bool`, *optional*): Whether to visualize the parameter information of the
                modified backbone.
            cuda_memory (:obj:`bool`, *optional*): Whether to report the allocated and peak CUDA memory.
        """
        if module is None:
            module = self.backbone_model
        if visualization:
            from opendelta import Visualization
            Visualization(module).structure_graph()
        self.get_statistics(module)
        if trainable_ratio:
            logger.info("Trainable Ratio: {:.2f}%".format(self.stat['trainable_ratio'] * 100))
        if delta_ratio:
            logger.info("Delta Parameter Ratio: {:.2f}%".format(self.stat['delta_ratio'] * 100))
        if cuda_memory:
            logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))

    def get_statistics(self, module=None):
        r"""Get the statistics of the parameters in the delta modules.

        Args:
            module (:obj:`nn.Module`, *optional*): The module to compute the statistics over.
                Defaults to the backbone model.

        The results are stored in :obj:`self.stat` (a :obj:`dict`).
        """
        if module is None:
            module = self.backbone_model
        self.stat = {}
        n_trainable = self.num_trainable_parameters(module)
        n_total = self.num_total_parameters(module)
        self.stat['trainable_ratio'] = n_trainable / n_total
        n_delta = self.num_delta_parameters(module)
        self.stat['delta_ratio'] = n_delta / n_total
        cudamem = 0
        maxcudamem = 0
        for device_id in range(torch.cuda.device_count()):
            cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}") / 1024 ** 3
            maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}") / 1024 ** 3
        self.stat['cudamem'] = cudamem
        self.stat['maxcudamem'] = maxcudamem

    def num_delta_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the number of delta parameters in the backbone model. Often
        used to compute the delta rate.

        Args:
            module (:obj:`nn.Module`): The module of which we want to know the number of delta parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self.backbone_model
        pnum_tot = 0
        for param in module.parameters():
            if hasattr(param, "_is_delta"):
                pnum_tot += param.numel()
        return pnum_tot

    # Two functions to plug and unplug the delta model.
    def attach(self, module: Optional[nn.Module] = None, reset_state_dict=True):
        r"""Reattach the delta modules to the backbone. Note that this method cannot be used to create new
        delta modules. Instead, a :meth:`DeltaBase.detach` should precede this method.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                reattach the deltas to.
        """
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "on":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['delta_module'])
                    elif _delta_info['method'] == "insert_sequential":
                        self.insert_sequential_module(module=submodule,
                                                      _delta_info=_delta_info)
                    elif _delta_info['method'] == "insert_parallel":
                        self.insert_parallel_module(module=submodule,
                                                    _delta_info=_delta_info)
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "on"
        if reset_state_dict:
            self.set_active_state_dict(module)

    def detach(self, module: Optional[nn.Module] = None, reset_state_dict=True):
        r"""Detach the delta modules from the backbone. The delta modules are not deleted, but temporarily
        turned off. Use :meth:`DeltaBase.attach` to reattach the delta model to the backbone.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                detach the deltas from.
        """
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "off":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['org_module'])
                    elif _delta_info['method'] in ("insert_sequential", "insert_parallel"):
                        if hasattr(submodule.forward, "__wrapped__"):
                            submodule.forward = submodule.forward.__wrapped__
                            delattr(submodule, _delta_info["delta_name"])
                        else:
                            raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "off"
        if reset_state_dict:
            try:
                module.state_dict = module.state_dict.__wrapped__
            except AttributeError:
                pass
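
    # Plug/unplug sketch (``delta`` is an attached DeltaBase instance; names are
    # placeholders):
    #
    #     delta.detach()  # the backbone now behaves as if no delta were attached
    #     delta.attach()  # the same delta modules are switched back on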