OpenDeltaMirror/opendelta/delta_models/bitfit.py

from typing import Optional, Union, List
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
import torch.nn as nn
import torch
from torch.nn import init
import math
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
logger = logging.get_logger(__name__)


class BitFitConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~BitFitModel`.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # the arg has not been registered in the parent config
                setattr(self, arg_name, locals()[arg_name])
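
# A minimal usage sketch for the config (illustrative only; it assumes the usual
# OpenDelta pattern of building a delta model from a config via ``AutoDeltaModel``,
# which is not defined in this file, and a ``backbone`` model defined elsewhere):
#
#   from opendelta import AutoDeltaModel
#   config = BitFitConfig()
#   delta_model = AutoDeltaModel.from_config(config, backbone_model=backbone)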


class BiasLayer(nn.Module):
    def __init__(self, init_method="zero"):
        super().__init__()
        self.init_method = init_method
        self.instantiated = False

    def instantiate(self, hidden_dim):
        if self.init_method == "zero":
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
        else:
            raise NotImplementedError
        self.instantiated = True

    def post_forward(self, output):
        r"""Presume the first element of the output is the hidden states, and add the bias
        along the last dimension. In most cases this presumption holds, but be aware that
        it may not.
        """
        if isinstance(output, tuple):
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError
        if not self.instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got hidden_dim {self.hidden_dim}")
            self.instantiate(hidden_dim=self.hidden_dim)
        modified_output = hiddens + self.bias
        if isinstance(output, tuple):
            output = (modified_output,) + output[1:]
        elif isinstance(output, torch.Tensor):
            output = modified_output
        else:
            raise TypeError
        return output
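
# A minimal sketch of how ``BiasLayer`` behaves (illustrative only, not part of the
# library's code path): the bias parameter is created lazily on the first call to
# ``post_forward``, sized to the last dimension of the incoming hidden states.
#
#   bias_layer = BiasLayer()
#   hidden = torch.randn(2, 16, 768)            # (batch, seq_len, hidden_dim)
#   out = bias_layer.post_forward(hidden)       # instantiates a 768-dim zero bias
#   assert out.shape == hidden.shape and bias_layer.instantiated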


class BitFitModel(DeltaBase):
    r"""The implementation of `BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models <https://arxiv.org/abs/2106.10199>`_ .
    Unfreezes the bias terms (or adds bias terms if they are absent in the backbone, e.g. T5) of the
    modules of a transformer block.

    .. note::

        **Broadcast to Submodule**: We modify all potential positions of the specified
        ``modified_modules``. That is to say, if we specify ``attn`` in the modified_modules, then every
        position inside it, including the q, k, v and out linear layers of the attention layer, gets a bias
        added (or unfrozen). The potential positions are determined according to equations (1)-(5) and the
        three preceding equations in the paper.

    class attributes:
        - default_modified_modules = ["attn", "ff", "layer_norm", "lm_head.proj"] According to the paper and the
          implementation in `Compacter's baseline <https://github.com/rabeehk/compacter>`_ , we modify the
          bias terms in the above modules.
        - delta_type = "bitfit"

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        modified_modules (:obj:`List[str]`): The modules whose bias terms are unfrozen (or, if absent, added).
            Defaults to :obj:`default_modified_modules`.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the bias parameters.
        common_structure (:obj:`bool`): Whether to use name-based addressing with a common structure mapping.

    """
    config_class = BitFitConfig
    delta_type = "bitfit"
    default_modified_modules = ["attn", "ff", "layer_norm", "lm_head.proj"]  # modify all the bias parameters in the attention and feed-forward layers.

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])
        self.delta_params = nn.ParameterList()
        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        _, _, ref = self.find_module(module, key)
        self.modify_module(ref)

    def modify_module(self,
                      module: nn.Module,
                      ):
        if is_leaf_module(module):
            # if it is a leaf module, add a bias to it regardless of its type.
            if isinstance(module, nn.Linear):
                self.add_bias_to_linear(module)
            else:
                # for example, layer_norms, lm_heads.
                self.add_bias_to_others(module)
        else:
            # for non-leaf modules, by default add a bias only to the linear submodules.
            for n, c in module.named_modules():
                if isinstance(c, nn.Linear):
                    if c.bias is None:
                        bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
                        c.register_parameter('bias', bias)
                        self._reset_bias_parameters(c)
                        self.delta_params.append(bias)
                    else:
                        c.bias.requires_grad = True
                        self.delta_params.append(c.bias)
                else:
                    pass
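
    # Illustrative sketch of the non-leaf branch above (not executed here; the module
    # path and attribute names are T5-specific assumptions, not guaranteed by this file):
    # for a T5 self-attention block, whose ``nn.Linear`` layers are built with
    # ``bias=False``, only those linear submodules (q, k, v, o) receive a new trainable bias.
    #
    #   attn = backbone.encoder.block[0].layer[0].SelfAttention
    #   delta_model.modify_module(attn)
    #   assert all(lin.bias is not None for lin in (attn.q, attn.k, attn.v, attn.o))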

    def add_bias_to_linear(self, c):
        if c.bias is None:
            bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
            c.register_parameter('bias', bias)
            self._reset_bias_parameters(c)
            self.delta_params.append(bias)
        else:
            c.bias.requires_grad = True
            self.delta_params.append(c.bias)

    def add_bias_to_others(self, c):
        new_bias = BiasLayer()
        # the delta name shouldn't be `bias` here, since the name `bias` is reserved by
        # some modules, such as RoBERTa's LayerNorm.
        self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit")
        self.delta_modules.append(new_bias)

    @staticmethod
    def _reset_bias_parameters(linear_module):
        # Mirror PyTorch's default nn.Linear bias initialization: uniform in
        # [-1/sqrt(fan_in), 1/sqrt(fan_in)], with fan_in computed from the weight.
        fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(linear_module.bias, -bound, bound)

    def detach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError

    def attach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError
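

# A minimal end-to-end usage sketch (illustrative only; it assumes `transformers` is
# installed and that the checkpoint name below is available — both are assumptions,
# not part of this file):
#
#   from transformers import AutoModelForSeq2SeqLM
#   backbone = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
#   delta_model = BitFitModel(backbone_model=backbone)
#   # freeze everything except the delta (bias) parameters before fine-tuning
#   delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
#   delta_model.log()  # inspect which parameters are now trainable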