2022-06-06 16:21:55 +08:00
|
|
|
|
2022-02-14 21:19:03 +08:00
|
|
|
from typing import List, Optional, Union
|
|
|
|
from opendelta.utils.signature import get_arg_names_inside_func
|
|
|
|
from opendelta.utils.name_based_addressing import *
|
|
|
|
from opendelta.utils.cuda import get_device
|
|
|
|
from opendelta.basemodel import DeltaBase
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch
|
|
|
|
from opendelta.delta_models.layers.activations import Activations
|
|
|
|
from opendelta import BaseDeltaConfig
|
|
|
|
import opendelta.utils.logging as logging
|
2022-03-13 01:21:55 +08:00
|
|
|
import numpy as np
|
|
|
|
from opendelta import global_setting
|
2022-07-01 22:23:02 +08:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
2022-02-14 21:19:03 +08:00
|
|
|
# Module-level logger obtained from opendelta's logging wrapper (used for debug messages below).
logger = logging.get_logger(__name__)
|
|
|
|
|
2022-03-13 01:21:55 +08:00
|
|
|
|
|
|
|
class InterFaceMixin:
    r"""Layout/dtype adapters shared by delta layers.

    Incoming tensors are moved into the layout and dtype the delta layer computes in, and the
    result is moved back afterwards:

    * ``_transpose`` / ``_reverse_transpose`` permute 3-dimensional tensors by
      ``global_setting.axis_order`` and its inverse; tensors of any other rank pass through.
    * ``_convert_data_type`` / ``_reverse_data_type`` cast to float32 on ``self._get_device()``
      and later restore the recorded dtype and device.

    Subclasses must implement ``_get_device()``.
    """

    def __init__(self):
        order = global_setting.axis_order
        self._axis_order = order
        # Assuming ``order`` is a permutation, its argsort is the inverse permutation,
        # so permuting by it undoes ``_transpose``.
        self._reverse_axis_order = np.argsort(order).tolist()

    def _transpose(self, tensor):
        """Permute a 3-D ``tensor`` by ``self._axis_order``; other ranks are returned untouched."""
        if tensor.dim() != 3:
            return tensor
        return tensor.permute(*self._axis_order)

    def _reverse_transpose(self, tensor):
        """Undo ``_transpose`` for a 3-D ``tensor`` (result made contiguous); other ranks untouched."""
        if tensor.dim() != 3:
            return tensor
        return tensor.permute(*self._reverse_axis_order).contiguous()

    def _convert_data_type(self, tensor):
        """Record ``tensor``'s dtype/device, then return it as float32 on ``self._get_device()``."""
        self._data_type_record, self._device_record = tensor.dtype, tensor.device
        as_float = tensor.to(torch.float32)
        return as_float.to(self._get_device())

    def _reverse_data_type(self, tensor):
        """Cast ``tensor`` back to the dtype and device recorded by ``_convert_data_type``."""
        restored = tensor.to(self._data_type_record)
        return restored.to(self._device_record)
|
|
|
|
|
|
|
|
|
|
|
|
|
2022-07-06 22:00:58 +08:00
|
|
|
|
|
|
|
|
2022-03-13 01:21:55 +08:00
|
|
|
class AdapterLayer(nn.Module, InterFaceMixin):
    r"""A layer of adapter tuning module: ``h + up_proj(act(down_proj(h)))``.

    The projections are built lazily on the first :py:meth:`post_forward` call, once the hidden
    size of the wrapped module's output is known (see :py:meth:`instantiate`).

    Args:
        bottleneck_dim (:obj:`int`, *optional*, defaults to 24): Output width of ``down_proj``.
        non_linearity (:obj:`str`, *optional*, defaults to ``'gelu_new'``): Name passed (lower-cased)
            to :py:class:`Activations`.
        device (*optional*): Device the projections are created on; also returned by
            :py:meth:`_get_device` before the layer is instantiated.
    """

    # Class-wide counter so every AdapterLayer gets a distinct ``layer_id``.
    layer_count = 0

    @classmethod
    def count_layer(cls):
        """Increment the class-wide layer counter."""
        cls.layer_count += 1

    @classmethod
    def get_layer_count(cls):
        """Return the current value of the class-wide layer counter."""
        return cls.layer_count

    def __init__(self, bottleneck_dim=24, non_linearity='gelu_new', device=None):
        super().__init__()
        InterFaceMixin.__init__(self)
        self.bottleneck_dim = bottleneck_dim
        self.init_device = device
        self.instantiated = False
        self.non_linearity = non_linearity

        self.layer_id = AdapterLayer.get_layer_count()
        AdapterLayer.count_layer()

    def _get_device(self):
        """Device the adapter computes on: ``down_proj``'s device once built, else ``init_device``."""
        if self.instantiated:
            return self.modulelist.down_proj.weight.device
        return self.init_device

    def instantiate(self, hidden_dim):
        """Build ``down_proj -> non_linear -> up_proj`` for inputs whose last dim is ``hidden_dim``.

        Args:
            hidden_dim (:obj:`int`): Hidden size of the wrapped module's output. It is also stored
                as ``self.hidden_dim``, so this method can be called directly, not only from
                :py:meth:`post_forward`.
        """
        # Record the size here rather than relying on the caller having set it beforehand.
        self.hidden_dim = hidden_dim
        self.modulelist = nn.Sequential()
        self.modulelist.add_module("down_proj", nn.Linear(hidden_dim, self.bottleneck_dim, device=self.init_device))

        # select non-linearity
        self.modulelist.add_module("non_linear", Activations(self.non_linearity.lower()))

        # Use the argument (not a previously-set attribute) so the shapes always agree with down_proj.
        self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, hidden_dim, device=self.init_device))

        # TODO:
        # If we want to have a layer norm on output, we apply it later after a separate residual connection
        # This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
        # if self.add_layer_norm_after:
        #     self.adapter_norm_after = nn.LayerNorm(self.input_size)

        self.instantiated = True
        # initialize the weight, which is important for fast convergence and better performance.
        self.apply(self._init_weight)

        # Optional BMTrain integration: best-effort, the plain nn.Sequential is kept if bmtrain is
        # missing or wrapping fails. ``except Exception`` (not bare ``except``) so that
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            import bmtrain as bmt
            self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
        except Exception as err:
            logger.debug(f"Skipping BMTrain wrapping of adapter layer {self.layer_id}: {err!r}")

    def _init_weight(self, module):
        """Initialize ``nn.Linear`` weights from N(0, 0.01^2) and zero their biases."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()

    def post_forward(self, output):
        r"""Get the hidden_states from the PLM's layer output, pass them into the adapter, then
        combine them with the main hidden_states. Finally pass the result into the subsequent layer.

        Args:
            output (:obj:`torch.Tensor` or :obj:`tuple`): The modified module's output; either the
                hidden states or a tuple whose first element is the hidden states.

        Returns:
            The same structure as ``output`` with the hidden states replaced by
            ``hidden_states + adapter(hidden_states)`` (restored to the original layout/dtype/device).

        Raises:
            TypeError: If ``output`` is neither a tensor nor a tuple.
        """
        is_tuple = isinstance(output, tuple)
        if is_tuple:
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError(
                "AdapterLayer.post_forward expects a torch.Tensor or a tuple whose first element is "
                f"the hidden states, got {type(output).__name__}"
            )

        hiddens = self._transpose(hiddens)
        hiddens = self._convert_data_type(hiddens)

        if not self.instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
            self.instantiate(hidden_dim=self.hidden_dim)

        adapter_output = self.modulelist(hiddens)
        modified_output = adapter_output + hiddens  # TODO option: disable residual_connection

        modified_output = self._reverse_transpose(modified_output)
        modified_output = self._reverse_data_type(modified_output)

        # ``output``'s type was validated above, so only the two valid shapes remain here.
        if is_tuple:
            return (modified_output,) + output[1:]
        return modified_output
|
2022-04-14 11:22:41 +08:00
|
|
|
|
|
|
|
|
2022-02-14 21:19:03 +08:00
|
|
|
|
|
|
|
class AdapterConfig(BaseDeltaConfig):
    r"""
    Configuration class holding the settings of an :py:class:`~AdapterModel`.

    Args:
        bottleneck_dim (:obj:`int`, *optional*, defaults to 24): The dimension of the adapter's bottleneck.
        non_linearity (:obj:`str`, *optional*, defaults to ``'gelu_new'``): The non-linearity of the adapter.
        **kwargs: Forwarded to :py:class:`BaseDeltaConfig`.
    """

    def __init__(
        self,
        bottleneck_dim: Optional[int] = 24,
        non_linearity: Optional[str] = 'gelu_new',
        **kwargs
    ):
        super().__init__(**kwargs)
        # Snapshot of this call's arguments, so each value can be looked up by its name.
        call_args = locals()
        for name in get_arg_names_inside_func(self.__init__):
            # A value already registered by the parent config takes precedence.
            if hasattr(self, name):
                continue
            setattr(self, name, call_args[name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdapterModel(DeltaBase):
    r""" The implementation of Adapter (`Parameter-Efficient Transfer Learning for NLP <https://arxiv.org/abs/1902.00751>`_).

    Adds an adapter after each of the designated ``modified_modules``. In the sequential paradigm, each
    modified module's output is passed into the adapter's ``post_forward``.

    .. note::
        We **assume** the output of the modified module is the hidden state or a tuple where the hidden
        state is the first element. This is true for most PLMs. However, we admit that currently it's not
        rigorous; we will improve it in the next version. Currently, if you encounter an error here for your
        backbone, you can modify the code to get the hidden state.

    class attributes:
        - default_modified_modules = ["attn@.proj@", "ff@.w2@"]: according to the Adapter paper, we add
          adapters to the attention layer and the feed-forward layer (the exact sub-modules are selected by
          these name-based addressing patterns).
        - delta_type = "adapter"

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
        bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
        non_linearity (:obj:`str`): The non-linearity of the adapter.
        modified_modules (:obj:`List[str]`): Modules to add adapters after.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the adapter parameters.
        common_structure (:obj:`bool`): Whether to use name-based addressing with a common structure mapping.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`False`): Forwarded to :py:class:`DeltaBase`.
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): Forwarded to :py:class:`DeltaBase`;
            presumably modules that should not receive an adapter even if they match ``modified_modules``.
    """
    config_class = AdapterConfig
    delta_type = "adapter"
    default_modified_modules = ["attn@.proj@", "ff@.w2@"]
    # Flag read by DeltaBase; presumably requests a dummy forward pass so that lazily-built adapters
    # (AdapterLayer.instantiate) create their weights — confirm in DeltaBase.
    _need_pseudo_data = True

    def __init__(self,
                 backbone_model: nn.Module,
                 bottleneck_dim: Optional[int] = 24,
                 non_linearity: Optional[str] = 'gelu_new',
                 modified_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 # Appended last (rather than next to ``modified_modules``) so existing positional
                 # arguments keep their meaning. Previously this name was forwarded to DeltaBase
                 # without being defined, which raised NameError on every construction.
                 exclude_modules: Optional[List[str]] = None,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        # Store every constructor argument the parent class has not already registered.
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])

        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        """Attach a new AdapterLayer sequentially after the submodule of ``module`` addressed by ``key``."""
        _, _, ref = self.find_module(module, key)
        adapterlayer = self.new_module_like(ref)
        self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="adapter")

    def new_module_like(self, module):
        """Create an AdapterLayer on ``module``'s device and register it in ``self.delta_modules``."""
        module_device = get_device(module)
        adapterlayer = AdapterLayer(bottleneck_dim=self.bottleneck_dim, non_linearity=self.non_linearity, device=module_device)
        self.delta_modules.append(adapterlayer)
        return adapterlayer
|