OpenDeltaMirror/opendelta/delta_models/bitfit.py

from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
import torch.nn as nn
import torch
from torch.nn import init
import math
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
logger = logging.get_logger(__name__)


class BitFitConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~BitFitModel`.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # the arg has not been registered in parent config
                setattr(self, arg_name, locals()[arg_name])
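
# Illustrative note (an assumption about BaseDeltaConfig, not stated in this file): BitFit adds no
# hyperparameters of its own, so BitFitConfig only records the arguments it shares with
# BaseDeltaConfig, e.g.
#
#     bitfit_config = BitFitConfig(modified_modules=["attn", "ff"])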


class BiasLayer(nn.Module):
    def __init__(self, init_method="zero"):
        super().__init__()
        self.init_method = init_method
        self.instantiated = False

    def instantiate(self, hidden_dim):
        if self.init_method == "zero":
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
        else:
            raise NotImplementedError
        self.instantiated = True

    def post_forward(self, output):
        r"""Presumes the first argument is the tensor to which the bias should be added
        along the last dimension. In most cases this is correct; however, be aware that
        the presumption may not hold.
        """
        if isinstance(output, tuple):
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError

        if not self.instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got hidden_dim: {self.hidden_dim}")
            self.instantiate(hidden_dim=self.hidden_dim)
        modified_output = hiddens + self.bias

        if isinstance(output, tuple):
            output = (modified_output,) + output[1:]
        elif isinstance(output, torch.Tensor):
            output = modified_output
        else:
            raise TypeError
        return output
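

# Illustrative sketch (not part of the original file): BiasLayer lazily creates its bias
# on the first forward pass from the observed hidden size and relies on broadcasting over
# the last dimension, e.g.
#
#     layer = BiasLayer()
#     hiddens = torch.randn(2, 8, 768)      # (batch, seq_len, hidden_dim)
#     out = layer.post_forward(hiddens)     # creates a zero bias of shape (768,)
#     assert out.shape == hiddens.shape     # out == hiddens + layer.bias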


class BitFitModel(DeltaBase):
    r"""The implementation of `BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models <https://arxiv.org/abs/2106.10199>`_ .
    Unfreezes the bias terms (or adds bias terms if they are absent in the backbone, e.g. T5) of the modules in
    a transformer block.

    .. note::

        **Broadcast to Submodule**: We modify all potential positions of the specified
        ``modified_modules``. That is to say, if we specify ``attn`` in ``modified_modules``, then all positions,
        including the q, k, v and output linear layers of the attention layer, have a bias layer added (or their
        existing bias unfrozen). The potential positions are determined according to equations (1)-(5) and the
        previous three equations.

    class attributes:

        - default_modified_modules = ["attn", "ff", "layer_norm", "lm_head.proj"] According to the paper and the
          implementation in `Compacter's baseline <https://github.com/rabeehk/compacter>`_ , we modify the
          bias terms in the above modules.
        - delta_type = "bitfit"

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        modified_modules (:obj:`List[str]`): The modules whose bias terms are unfrozen (or added if absent).
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that are excluded from modification.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the bias parameters.
        common_structure (:obj:`bool`): whether to use name-based addressing with a common structure mapping.

    """
    config_class = BitFitConfig
    delta_type = "bitfit"
    default_modified_modules = ["attn@", "ff@", "layer_norm@", "lm_head@.proj@"]  # modify all the bias parameters in the attention and feed-forward layers.
    _need_pseudo_data = False
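
    # A minimal usage sketch (illustrative; the backbone checkpoint and module names below are
    # examples, and ``freeze_module`` follows OpenDelta's usual workflow rather than anything
    # specific to this file):
    #
    #     from transformers import AutoModelForSeq2SeqLM
    #     backbone = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    #     delta_model = BitFitModel(backbone_model=backbone, modified_modules=["attn", "ff"])
    #     delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)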

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])
        self.delta_params = nn.ParameterList()
        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        _, _, ref = self.find_module(module, key)
        self.modify_module(ref)

    def modify_module(self,
                      module: nn.Module,
                      ):
        if is_leaf_module(module):
            # if it is a leaf module, add bias to it regardless of its type.
            if isinstance(module, nn.Linear):
                self.add_bias_to_linear(module)
            else:
                # for example, layer_norms, lm_heads.
                self.add_bias_to_others(module)
        else:
            # for non-leaf modules, by default add (or unfreeze) the bias only in the Linear
            # and LayerNorm submodules.
            for n, c in module.named_modules():
                if isinstance(c, nn.Linear) or isinstance(c, nn.LayerNorm):
                    if c.bias is None:
                        # note: creating a new bias assumes the submodule exposes ``out_features``
                        # (always true for nn.Linear).
                        bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
                        c.register_parameter('bias', bias)
                        self._reset_bias_parameters(c)
                        self.delta_params.append(bias)
                    else:
                        c.bias.requires_grad = True
                        self.delta_params.append(c.bias)
                else:
                    pass
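
    # Note (descriptive comment, not in the original file): when a non-leaf module such as a whole
    # attention block is matched, the loop above unfreezes (or newly adds) the bias of every
    # nn.Linear child (e.g. the q, k, v and output projections) and of every nn.LayerNorm child.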

    def add_bias_to_linear(self, c):
        if c.bias is None:
            bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
            c.register_parameter('bias', bias)
            self._reset_bias_parameters(c)
            self.delta_params.append(bias)
        else:
            c.bias.requires_grad = True
            self.delta_params.append(c.bias)

    def add_bias_to_others(self, c):
        new_bias = BiasLayer()
        # The delta name shouldn't be `bias` here, since the name `bias` is reserved in some
        # modules, e.g. RoBERTa's LayerNorm.
        self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit")
        self.delta_modules.append(new_bias)

    @staticmethod
    def _reset_bias_parameters(linear_module):
        # Mirror torch.nn.Linear's default bias initialization: uniform in
        # (-1/sqrt(fan_in), 1/sqrt(fan_in)), so a newly added bias starts in the same
        # regime as one the backbone would have had.
        fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(linear_module.bias, -bound, bound)

    def detach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError

    def attach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError