2022-02-14 21:19:03 +08:00
from typing import Optional , Union
from opendelta . utils . signature import get_arg_names_inside_func
from opendelta . utils . name_based_addressing import *
from opendelta . basemodel import DeltaBase , is_leaf_module
2022-09-03 18:12:12 +08:00
from opendelta . utils . cuda import get_device , get_dtype
2022-02-14 21:19:03 +08:00
import torch . nn as nn
import torch
from torch . nn import init
import math
from opendelta import BaseDeltaConfig
import opendelta . utils . logging as logging
# Module-level logger obtained from opendelta's logging utilities; BiasLayer reports the hidden size it infers through it.
logger = logging . get_logger ( __name__ )
class BitFitConfig(BaseDeltaConfig):
    r"""
    Configuration class holding the settings of a :py:class:`~BitFitModel`.

    BitFit declares no named hyper-parameters of its own; all keyword arguments are
    forwarded to :py:class:`BaseDeltaConfig`.
    """
    def __init__(
        self,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Copy every named argument of this __init__ onto the config, unless the parent
        # config already registered an attribute with that name.
        frame_vars = locals()
        for name in get_arg_names_inside_func(self.__init__):
            if hasattr(self, name):
                continue
            setattr(self, name, frame_vars[name])
class BiasLayer(nn.Module):
    r"""A lazily created additive bias, inserted after a module that has no bias of its own.

    The bias vector is only created on the first call to :meth:`post_forward`, once the size of the
    last dimension of the wrapped module's output is known.

    Args:
        init_method (:obj:`str`, *optional*, default to ``"zero"``): How to initialize the bias.
            Only ``"zero"`` is implemented.
        dtype (:obj:`torch.dtype`, *optional*): dtype of the bias. If :obj:`None`, the dtype of the
            first output seen by :meth:`post_forward` is used.
        device (:obj:`torch.device`, *optional*): device of the bias. If :obj:`None`, the device of the
            first output seen by :meth:`post_forward` is used.
    """
    def __init__(self, init_method="zero", dtype=None, device=None):
        super().__init__()
        self.init_method = init_method
        self.instantiated = False
        self.dtype = dtype
        self.device = device

    def instantiate(self, hidden_dim, dtype=None, device=None):
        r"""Create the bias parameter of size ``hidden_dim``.

        ``self.dtype`` / ``self.device`` take precedence. The ``dtype`` / ``device`` arguments are only
        used when the corresponding attribute is :obj:`None`.

        Raises:
            NotImplementedError: if ``self.init_method`` is not ``"zero"``.
        """
        dtype = self.dtype if self.dtype is not None else dtype
        device = self.device if self.device is not None else device
        if self.init_method == "zero":
            self.bias = nn.Parameter(torch.zeros(hidden_dim, dtype=dtype, device=device))
        else:
            raise NotImplementedError(f"Unsupported init_method for BiasLayer: {self.init_method!r}")
        self.instantiated = True

        # Best-effort BMTrain integration: skipped silently when bmtrain is not installed, and
        # logged (not raised) if the wrapping itself fails.
        try:
            import bmtrain as bmt
        except ImportError:
            return
        try:
            self.bias = bmt.BMTrainModelWrapper(self.bias)
        except Exception as err:  # deliberate best effort, as before
            logger.debug(f"Skipped BMTrain wrapping of BiasLayer.bias: {err}")

    def post_forward(self, output):
        r"""Add the bias to the wrapped module's output along its last dimension.

        Presumes the output is a tensor, or a tuple whose first element is the tensor to modify.
        In most cases this is correct; be aware that the presumption may not hold.

        Raises:
            TypeError: if ``output`` is neither a tensor nor a tuple.
        """
        is_tuple = isinstance(output, tuple)
        if is_tuple:
            hiddens = output[0]
        elif isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError(
                "BiasLayer expects a tensor or a tuple whose first element is a tensor, "
                f"got {type(output).__name__}"
            )

        if not self.instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got hidden dim {self.hidden_dim}")
            # Fall back to the activations' dtype/device so the addition below neither
            # mixes devices nor silently promotes the dtype.
            self.instantiate(hidden_dim=self.hidden_dim, dtype=hiddens.dtype, device=hiddens.device)

        modified_output = hiddens + self.bias
        if is_tuple:
            return (modified_output,) + output[1:]
        return modified_output
class BitFitModel(DeltaBase):
    r""" The implementation of `BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models <https://arxiv.org/abs/2106.10199>`_ .
    Unfreeze the bias terms (or add bias terms where the backbone has none, e.g. T5) of the modules of
    a transformer block.

    .. note::

        **Broadcast to Submodule**: We modify all potential positions of the specified
        ``modified_modules``. That is to say, if we specify ``attn`` in the modified_modules, then all positions
        including the q, k, v and out linear layers of the attention layer get a bias (unfrozen or added).
        The potential positions are determined according to equations (1)-(5) of the paper and the three
        equations preceding them.

    class attributes:
        - default_modified_modules = ["attn@", "ff@", "layer_norm@", "lm_head@.proj@"] According to the paper and the
          implementation in `Compacter's baseline <https://github.com/rabeehk/compacter>`_ , we modify the
          bias terms in the above modules.
        - delta_type = "bitfit"

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules whose bias terms
            are unfrozen (or added when absent).
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): Modules excluded from
            modification; forwarded to :class:`DeltaBase`.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the bias parameters.
        common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`False`): forwarded to
            :class:`DeltaBase`.
        backend (:obj:`str`, *optional*, default to ``"hf"``): the model backend. Only ``"hf"`` is listed in
            ``_supported_backends``; with ``"bmt"`` the newly created biases are wrapped by BMTrain.
    """
    config_class = BitFitConfig
    delta_type = "bitfit"
    default_modified_modules = ["attn@", "ff@", "layer_norm@", "lm_head@.proj@"]  # modify all the bias parameter in attention and feed-forward layer.
    _supported_backends = ['hf']
    _need_pseudo_data = False

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 backend: Optional[str] = "hf",
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           backend=backend,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])
        # Unfrozen / newly created bias parameters and inserted BiasLayers are collected here.
        self.delta_params = nn.ParameterList()
        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules)

    def update_module(self, module: nn.Module, key: str):
        r"""Locate the submodule addressed by ``key`` inside ``module`` and apply BitFit to it."""
        _, _, ref = self.find_module(module, key)
        self.modify_module(ref)

    def modify_module(self,
                      module: nn.Module,
                      ):
        r"""Apply BitFit to ``module``.

        * A leaf module that owns a ``bias`` parameter, or is a linear / layer-norm module: its bias is
          unfrozen or created by :meth:`add_bias_to_modules_have_bias_or_known_type`.
        * Any other leaf module: a :class:`BiasLayer` is inserted after it.
        * A non-leaf module: the module itself and every descendant are passed to
          :meth:`add_bias_to_modules_have_bias_or_known_type`.
        """
        if is_leaf_module(module):
            if self._has_own_bias(module) or self._is_known_type(module):
                self.add_bias_to_modules_have_bias_or_known_type(module)
            else:
                self.add_bias_to_others(module)
        else:
            # Snapshot the tree first: adding biases / BiasLayers mutates the modules being visited.
            for sub_module in list(module.modules()):
                self.add_bias_to_modules_have_bias_or_known_type(sub_module)

    def add_bias_to_modules_have_bias_or_known_type(self, c):
        r"""Unfreeze or create the bias of ``c``.

        * If ``c`` owns a ``bias`` parameter, it is unfrozen and recorded in ``self.delta_params``.
        * Otherwise, if ``c`` is a linear or layer-norm module according to the backend mapping, a new
          trainable bias matching ``c``'s weight dtype/device is registered on it (linear:
          ``nn.Linear``-style uniform init; layer norm: zeros). If the bias size cannot be determined
          from ``c``, a :class:`BiasLayer` is inserted after ``c`` instead.
        * Any other module is left untouched.
        """
        if self._has_own_bias(c):
            c.bias.requires_grad = True
            self.delta_params.append(c.bias)
            return
        if not self._is_known_type(c):
            return

        bias = self._make_bias_parameter(c)
        if bias is None:
            self.add_bias_to_others(c)
            return
        if self.backend == 'bmt':
            import bmtrain as bmt
            # NOTE(review): register_parameter() only accepts nn.Parameter (or None); confirm the
            # BMTrain wrapper satisfies that.
            bias = bmt.BMTrainModelWrapper(bias)
        c.register_parameter('bias', bias)
        self.delta_params.append(bias)

    def add_bias_to_others(self, c):
        r"""Insert a lazily created :class:`BiasLayer` after ``c`` and record it in ``self.delta_modules``."""
        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c))
        if self.backend == 'bmt':
            import bmtrain as bmt
            new_bias = bmt.BMTrainModelWrapper(new_bias)
        self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit")  # name shouldn't be `bias` here, since the name `bias` is reserved for some module such as roberta's LayerNorm.
        self.delta_modules.append(new_bias)

    @staticmethod
    def _has_own_bias(c):
        r"""Whether ``c`` directly owns a (non-``None``) parameter named ``bias``."""
        return any(name == "bias" for name, _ in c.named_parameters(recurse=False))

    def _is_known_type(self, c):
        r"""Whether the backend mapping classifies ``c`` as a linear or layer-norm module."""
        return self.backend_mapping.check_type(c, 'linear') or \
            self.backend_mapping.check_type(c, 'layer_norm')

    def _make_bias_parameter(self, c):
        r"""Build an initialized, trainable bias for the linear / layer-norm module ``c``.

        Returns :obj:`None` when ``c`` exposes neither ``out_features`` (together with a matrix
        ``weight``) nor ``normalized_shape``.
        """
        weight = getattr(c, 'weight', None)
        # Match the module's dtype/device so the new bias can be combined with its output directly.
        factory = {}
        if isinstance(weight, torch.Tensor):
            factory = {'dtype': weight.dtype, 'device': weight.device}

        out_features = getattr(c, 'out_features', None)
        if out_features is not None and isinstance(weight, torch.Tensor) and weight.dim() >= 2:
            bias = nn.Parameter(torch.empty(out_features, **factory), requires_grad=True)
            # Initialize the new tensor itself: c.bias is still None at this point.
            self._reset_bias_parameters(c, bias)
            return bias

        normalized_shape = getattr(c, 'normalized_shape', None)
        if normalized_shape is not None:
            # torch.nn.LayerNorm also initializes its bias to zeros.
            return nn.Parameter(torch.zeros(normalized_shape, **factory), requires_grad=True)
        return None

    @staticmethod
    def _reset_bias_parameters(linear_module, bias=None):
        r"""Initialize a linear bias uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``, as ``nn.Linear`` does.

        Args:
            linear_module: module whose ``weight`` determines ``fan_in``.
            bias (*optional*): tensor to initialize in place; defaults to ``linear_module.bias``.
        """
        target = linear_module.bias if bias is None else bias
        fan_in, _ = init._calculate_fan_in_and_fan_out(linear_module.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(target, -bound, bound)

    def detach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError

    def attach(self, module):
        r"""Not implemented for BitFit yet. Please wait for the next version.
        """
        raise NotImplementedError