from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
import torch.nn as nn
import torch
from torch.nn import init
import math
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging

logger = logging.get_logger(__name__)

class BitFitConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~BitFitModel`.

    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # the arg has not been registered in parent config
                setattr(self, arg_name, locals()[arg_name])

class BiasLayer(nn.Module):
    def __init__(self, init_method="zero"):
        super().__init__()
        self.init_method = init_method
        self.instantiated = False

    def instantiate(self, hidden_dim):
        if self.init_method == "zero":
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
        else:
            raise NotImplementedError
        self.instantiated = True

    def post_forward(self, output):
        r"""Presume the first element of ``output`` is the hidden-states tensor, and add the bias
        along its last dimension. This holds in most cases, but be aware that the presumption
        may not hold for every module.
        """
        if isinstance(output, tuple):
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError(f"Unexpected output type {type(output)}")

        if not self.instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got hidden_dim {self.hidden_dim}")
            self.instantiate(hidden_dim=self.hidden_dim)
        modified_output = hiddens + self.bias

        if isinstance(output, tuple):
            output = (modified_output,) + output[1:]
        elif isinstance(output, torch.Tensor):
            output = modified_output
        else:
            raise TypeError(f"Unexpected output type {type(output)}")
        return output

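# A minimal illustration (not part of the original module) of how BiasLayer behaves: it lazily
# creates its bias parameter the first time ``post_forward`` sees an output, so it can be inserted
# without knowing the hidden dimension in advance. A sketch, assuming a plain tensor input:
#
#     bias_layer = BiasLayer()
#     hidden = torch.randn(2, 4, 768)          # (batch, seq_len, hidden_dim)
#     out = bias_layer.post_forward(hidden)    # instantiates a zero-initialized bias of size 768
#     assert torch.equal(out, hidden)          # a zero bias leaves the output unchanged
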
class BitFitModel(DeltaBase):
    r"""The implementation of `BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models <https://arxiv.org/abs/2106.10199>`_ .
    Unfreeze the bias terms of the modules in a transformer block (or add bias terms if they are
    absent in the backbone, e.g. T5).

    .. note::

        **Broadcast to Submodule**: We modify all potential positions of the specified
        ``modified_modules``. That is to say, if we specify ``attn`` in ``modified_modules``, then every
        position inside the attention layer, including the q, k, v and output linear layers, has a bias
        layer added (or its bias unfrozen). The potential positions are determined according to
        equations (1)-(5) and the three preceding equations.

    class attributes:

        - default_modified_modules = ["attn@", "ff@", "layer_norm@", "lm_head@.proj@"]. According to the paper and the
          implementation in `Compacter's baseline <https://github.com/rabeehk/compacter>`_, we modify the
          bias terms in the above modules.
        - delta_type = "bitfit"

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules whose bias
            terms are unfrozen (or added when absent).
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the bias parameters.
        common_structure (:obj:`bool`): whether to use name-based addressing with a common structure mapping.

    """
    config_class = BitFitConfig
    delta_type = "bitfit"
    default_modified_modules = ["attn@", "ff@", "layer_norm@", "lm_head@.proj@"]  # modify all the bias parameters in the attention and feed-forward layers.
    _need_pseudo_data = False

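    # A minimal usage sketch (an illustration, not part of the original module), assuming a
    # Hugging Face backbone and the ``DeltaBase.freeze_module`` interface; with no
    # ``modified_modules`` given, the defaults above are used:
    #
    #     from transformers import AutoModelForSequenceClassification
    #     backbone = AutoModelForSequenceClassification.from_pretrained("roberta-base")
    #     delta_model = BitFitModel(backbone_model=backbone)
    #     delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
    #     # only the bias terms (and inserted BiasLayer deltas) remain trainable
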
    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])
        self.delta_params = nn.ParameterList()
        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        _, _, ref = self.find_module(module, key)
        self.modify_module(ref)

    def modify_module(self,
                      module: nn.Module,
                      ):
        if is_leaf_module(module):
            # if it is a leaf module, add bias to it regardless of its type.
            if isinstance(module, nn.Linear):
                self.add_bias_to_linear(module)
            else:
                # for example, layer_norms, lm_heads.
                self.add_bias_to_others(module)
        else:
            # for non-leaf modules, by default only the Linear and LayerNorm submodules get a bias
            # (added when absent, otherwise unfrozen).
            for n, c in module.named_modules():
                if isinstance(c, nn.Linear):
                    self.add_bias_to_linear(c)
                elif isinstance(c, nn.LayerNorm):
                    if c.bias is None:
                        # nn.LayerNorm has no ``out_features``; use its normalized shape and a zero init.
                        bias = nn.Parameter(torch.zeros(c.normalized_shape), requires_grad=True)
                        c.register_parameter('bias', bias)
                        self.delta_params.append(bias)
                    else:
                        c.bias.requires_grad = True
                        self.delta_params.append(c.bias)
                else:
                    pass

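    # A brief illustration (not part of the original module), using a hypothetical BitFitModel
    # instance ``delta`` and a hypothetical submodule path: a leaf module gets a bias directly,
    # while a composite module only has its Linear/LayerNorm children touched:
    #
    #     delta.modify_module(nn.Linear(768, 768, bias=False))  # registers a new trainable bias
    #     delta.modify_module(backbone.encoder.block[0])        # unfreezes (or creates) the biases
    #                                                           # of its Linear/LayerNorm submodules
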
    def add_bias_to_linear(self, c):
        if c.bias is None:
            bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
            c.register_parameter('bias', bias)
            self._reset_bias_parameters(c)
            self.delta_params.append(bias)
        else:
            c.bias.requires_grad = True
            self.delta_params.append(c.bias)

    def add_bias_to_others(self, c):
        new_bias = BiasLayer()
        # The delta name shouldn't be `bias` here, since the name `bias` is already taken by some
        # modules, such as RoBERTa's LayerNorm.
        self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit")
        self.delta_modules.append(new_bias)

    @staticmethod
    def _reset_bias_parameters(linear_module):
        # Mirror nn.Linear's default bias initialization: uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(linear_module.bias, -bound, bound)

    def detach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError

    def attach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError