2022-02-14 21:19:03 +08:00
|
|
|
|
|
|
|
|
|
from opendelta.basemodel import DeltaBase
|
|
|
|
|
from opendelta.delta_configs import BaseDeltaConfig
|
|
|
|
|
from opendelta.delta_models.layers.low_rank_linear import LowRankLinear
|
|
|
|
|
from opendelta.delta_models.layers.activations import Activations
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
from opendelta.utils.signature import get_arg_names_inside_func
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch
|
|
|
|
|
from typing import Optional
|
|
|
|
|
from opendelta.utils.name_based_addressing import *
|
|
|
|
|
from opendelta.utils.cuda import get_device
|
|
|
|
|
from opendelta.basemodel import DeltaBase
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch
|
|
|
|
|
import math
|
|
|
|
|
import opendelta.utils.logging as logging
|
|
|
|
|
# Module-level logger, created through OpenDelta's logging helper and named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LowRankAdapterConfig(BaseDeltaConfig):
    r"""
    Configuration container for a :py:class:`~LowRankAdapterModel`.

    Besides whatever the parent config registers from ``**kwargs``, the
    constructor accepts ``reduction_factor``, ``non_linearity``,
    ``low_rank_w_init`` and ``low_rank_rank``. Every constructor argument that
    the parent config has not already set as an attribute is stored on the
    instance under its own name.

    """
    def __init__(
        self,
        reduction_factor=32,
        non_linearity="gelu_new",
        low_rank_w_init="glorot-uniform",
        low_rank_rank=1,
        **kwargs
    ):
        # Capture the call arguments once, before any other local is created.
        call_args = locals()
        super().__init__(**kwargs)
        for name in get_arg_names_inside_func(self.__init__):
            if hasattr(self, name):
                # Already registered by the parent config; keep its value.
                continue
            setattr(self, name, call_args[name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LowRankAdapter(nn.Module):
    r"""Low-rank adapter layer applied to the output of a backbone module.

    The adapter is a bottleneck: a down projection, a non-linearity and an up
    projection, whose result is added back to the incoming hidden states
    (residual connection). Both projections are :class:`LowRankLinear` layers
    built with ``rank=low_rank_rank``.

    The sub-layers are not created in ``__init__``. They are built by
    :meth:`instantiate` on the first call to :meth:`post_forward`, because the
    hidden size and dtype are read from the first hidden states seen.

    Args:
        reduction_factor: The bottleneck size is
            ``hidden_dim // reduction_factor``.
        non_linearity: Name of the activation; it is lower-cased and passed to
            :class:`Activations`.
        low_rank_w_init: Forwarded as ``w_init`` to :class:`LowRankLinear`.
        low_rank_rank: Forwarded as ``rank`` to :class:`LowRankLinear`.
        device: Device the sub-layers are moved to when they are created.
        backend: When equal to ``'bmt'``, the sub-layers are additionally
            wrapped in ``bmtrain.BMTrainModelWrapper``.
    """

    def __init__(self,
                 reduction_factor=32,
                 non_linearity="gelu_new",
                 low_rank_w_init="glorot-uniform",
                 low_rank_rank=1,
                 device=None,
                 backend='hf'):
        super().__init__()
        self.reduction_factor = reduction_factor
        self.non_linearity = non_linearity
        self.low_rank_w_init = low_rank_w_init
        self.low_rank_rank = low_rank_rank
        self.device = device
        # Flipped to True by ``instantiate`` once the sub-layers exist.
        self.instantiated = False
        self.backend = backend

    def instantiate(self, hiddens):
        r"""Create the adapter sub-layers from the first hidden states seen.

        Args:
            hiddens: Tensor whose last dimension is taken as the hidden size
                and whose dtype is used for the projection weights.
        """
        self.hidden_dim = hiddens.shape[-1]
        self.hidden_dtype = hiddens.dtype

        # Floor division: this is 0 whenever hidden_dim < reduction_factor.
        self.down_sample_size = self.hidden_dim // self.reduction_factor
        self.activation = Activations(self.non_linearity.lower()).to(self.device)
        self.down_sampler = LowRankLinear(self.hidden_dim, self.down_sample_size,
                                          w_init=self.low_rank_w_init,
                                          rank=self.low_rank_rank,
                                          dtype=self.hidden_dtype).to(self.device)
        self.up_sampler = LowRankLinear(self.down_sample_size, self.hidden_dim,
                                        w_init=self.low_rank_w_init,
                                        rank=self.low_rank_rank,
                                        dtype=self.hidden_dtype).to(self.device)

        self.instantiated = True
        if self.backend == 'bmt':
            # Imported lazily so that bmtrain is only required for this backend.
            import bmtrain as bmt
            self.activation = bmt.BMTrainModelWrapper(self.activation)
            self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
            self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)

    def post_forward(self, output):
        r"""Apply the adapter to the output of the wrapped backbone module.

        The hidden states are taken from ``output`` (the tensor itself, or the
        first element of a tuple), passed through the down projection, the
        activation and the up projection, and added back to the original
        hidden states.

        Args:
            output: A ``torch.Tensor`` or a tuple whose first element is the
                hidden states.

        Returns:
            The same kind of object as ``output``: the modified tensor, or a
            tuple whose first element is replaced by the modified hidden
            states and whose remaining elements are unchanged.

        Raises:
            TypeError: If ``output`` is neither a tuple nor a ``torch.Tensor``.
        """
        if isinstance(output, tuple):
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError(
                "LowRankAdapter.post_forward expects a torch.Tensor or a tuple "
                "whose first element is the hidden states, "
                f"got {type(output).__name__}"
            )

        if not self.instantiated:
            self.instantiate(hiddens=hiddens)

        z = self.down_sampler(hiddens)
        z = self.activation(z)
        adapter_output = self.up_sampler(z)

        modified_output = adapter_output + hiddens  # residual connection

        # ``output`` was validated above, so only the two shapes remain.
        if isinstance(output, tuple):
            return (modified_output,) + output[1:]
        return modified_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LowRankAdapterModel(DeltaBase):
    r""" The implementation of LowRankAdapter, proposed as a baseline in
    `Compacter: Efficient Low-Rank Hypercomplex Adapter Layers <https://arxiv.org/abs/2106.04647>`_ .
    We found that it has very few parameters but competitive performance, and therefore added it to OpenDelta.
    A low-rank adapter parameterizes each adapter's weight as a product of two low-rank weights
    (rank one by default, see ``low_rank_rank``).

    A low-rank adapter layer is added to each of the designated ``modified_modules``. In the sequential paradigm,
    the module's output is passed into the low-rank adapter's ``post_forward``.

    .. note::
        We **assume** the output of the modified module is the hidden state or a tuple where the hidden state is the
        first element. This is true for most PLMs. However, we admit that currently this is not rigorous; we will
        improve it in the next version. Currently, if you encounter an error here for your backbone, you can modify
        the code to get the hidden state.

    The default hyperparameters are adopted from the `compacter code base <https://github.com/rabeehk/compacter>`_ .

    class attributes:
        - default_modified_modules = ["attn@.proj@", "ff@.w2@"] According to the compacter paper, we add a low-rank
          adapter to the attention layer and to the feed forward layer.
        - delta_type = "low_rank_adapter"

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
        reduction_factor (:obj:`int`, *optional*, default to ``32``): bottleneck_dim = hidden_dim//reduction_factor
        non_linearity (:obj:`str`, *optional*, default to ``"gelu_new"``): The non-linearity used between the down
            projector and the up projector.
        low_rank_w_init (:obj:`str`, *optional*, default to ``"glorot-uniform"``): The weight init method of the
            factorized linear weight.
        low_rank_rank (:obj:`int`, *optional*, default to 1): The rank of the low-rank decomposition.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules after which a low-rank
            adapter is inserted. Passed to :class:`DeltaBase`.
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): Passed unchanged to :class:`DeltaBase`.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the adapter parameters.
        common_structure (:obj:`bool`, *optional*, default to :obj:`None`): whether using name-based addressing with
            a common structure mapping.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`False`): Passed unchanged to
            :class:`DeltaBase`.
        backend (:obj:`str`, *optional*, default to ``"hf"``): Passed to :class:`DeltaBase` and to every
            :class:`LowRankAdapter` created by this model. ``_supported_backends`` lists ``'hf'`` and ``'bmt'``.

    """

    config_class = LowRankAdapterConfig
    delta_type = "low_rank_adapter"
    default_modified_modules = ["attn@.proj@", "ff@.w2@"]
    _supported_backends = ['hf', 'bmt']
    # NOTE(review): presumably tells DeltaBase to run a pseudo forward pass so that the
    # lazily-built adapters get instantiated -- confirm against DeltaBase.
    _need_pseudo_data = True

    def __init__(self,
                 backbone_model: nn.Module,
                 reduction_factor = 32,
                 non_linearity = "gelu_new",
                 low_rank_w_init = "glorot-uniform",
                 low_rank_rank = 1,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 backend: Optional[str] = 'hf',
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           backend=backend,
                           )
        # Store every constructor argument the parent class did not already
        # register as an attribute (e.g. the adapter hyperparameters).
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])

        # Holds every adapter created by ``new_module_like``.
        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        r"""Insert a new low-rank adapter after the submodule addressed by ``key``.

        Args:
            module: The module in which ``key`` is looked up via ``find_module``.
            key: The name of the submodule to modify.
        """
        _, _, ref = self.find_module(module, key)
        adapterlayer = self.new_module_like(ref)
        self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="low_rank_adapter")

    def new_module_like(self, module):
        r"""Create a :class:`LowRankAdapter` on the same device as ``module``.

        The adapter is built from this model's hyperparameters and backend, and
        is appended to ``self.delta_modules`` before being returned.

        Args:
            module: The reference module whose device is used for the adapter.

        Returns:
            The newly created :class:`LowRankAdapter`.
        """
        module_device = get_device(module)
        adapterlayer = LowRankAdapter(reduction_factor = self.reduction_factor,
                                      non_linearity = self.non_linearity,
                                      low_rank_w_init = self.low_rank_w_init,
                                      low_rank_rank = self.low_rank_rank,
                                      device=module_device, backend=self.backend)
        self.delta_modules.append(adapterlayer)
        return adapterlayer
|