from collections import OrderedDict
import os
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.model_md5 import gen_model_hash
from opendelta.utils.signature import get_arg_names, signature
from typing import Optional, Union
from opendelta.utils.cuda import get_device
from opendelta.utils.name_based_addressing import *
import torch.nn as nn
import torch
from functools import wraps
# from decorator import decorate
from opendelta.utils.decorate import decorate
from opendelta.utils.structure_mapping import transform
from transformers.file_utils import PushToHubMixin
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from opendelta import SaveLoadMixin
from opendelta import logging
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.cuda import move_dict_to_cuda
import sys

logger = logging.get_logger(__name__)


def is_leaf_module(module):
    r"""Whether the module is a leaf module, i.e., it has no child modules."""
    return len([n for n, _ in module.named_children()]) == 0


def non_module_param(module: nn.Module):
    r"""Return the (name, parameter) pairs that belong to :obj:`module` itself
    rather than to any of its child modules."""
    module_names = [n for n, _ in module.named_modules()]
    ret = []
    for n, p in module.named_parameters():
        if not is_child_key(n, module_names):
            ret.append((n, p))
    return ret


class DeltaBase(nn.Module, SaveLoadMixin):
    r"""This is the base class for all delta models. It provides four simple but effective functionalities
    for building the delta model:

    #. addressing a module inside the backbone model using a minimal description key.
    #. providing an interface for modifying and inserting modules that keeps the docs/IO the same as the
       module before modification.
    #. passing a pseudo input to determine the inner dimension of the delta models.
    #. freezing a part of the model parameters according to keys.

    It also provides a unified interface for model loading and saving.

    Class attributes (overridden by derived classes):

        - delta_type (:obj:`str`): the name of the delta modules, used to create the correct :class:`opendelta.AutoDeltaModel`.
        - config_class (:class:`BaseDeltaConfig`): The corresponding config class.

    Args:
        backbone_model (:obj:`nn.Module`, *required*): backbone model that the delta models are built upon. The
            modification to the backbone model is in place.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that are subject to update.

            .. note::
                Leaving this argument as :obj:`None` makes the delta model fall back to the default setting, which adds
                the delta modules to the positions experimented with in the paper. In this setting, the common structure
                mapping is loaded to address the corresponding modules.

        exclude_modules (:obj:`str`, *optional*, default to :obj:`None`): The modules starting with these strings will be
            excluded from modification. Note that currently only plain text (no regular expression) is supported.
        unfrozen_modules (:obj:`str`, *optional*, default to :obj:`None`): The modules that are **not** frozen when freezing
            the main part of the model.
        registration_name (:obj:`str`, *optional*, default to ``"deltas"``): The root name of the delta models when
            attached to the backbone model.
        common_structure (:obj:`bool`, *optional*, default to :obj:`None`): Whether to use the common structure mapping to
            specify the modified_modules, i.e., if common_structure=True, a common key such as ``["attn"]`` addresses the
            attention module in different models. We DO NOT recommend manually setting ``common_structure`` to ``True``
            unless you are using delta across multiple backbones and don't want to modify the code.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`None`): Whether to use interactive
            modification. Setting it to an :obj:`int` specifies the port of the web server.
    """
    delta_type = ""
    default_modified_modules = []
    default_exclude_modules = ["lm_head"]
    config_class = BaseDeltaConfig
    default_unfrozen_modules = ["deltas"]
    _need_pseudo_data = True

    def __init__(self,
                 backbone_model: nn.Module,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 common_structure: Optional[bool] = False,
                 ):
        nn.Module.__init__(self)
        # Register the backbone model through self.__dict__ to avoid adding it
        # to the submodules of the delta model.
        self.__dict__["backbone_model"] = backbone_model
        if modified_modules is None and exclude_modules is None:
            if interactive_modify:
                if isinstance(interactive_modify, bool) and interactive_modify is True:
                    self.modified_modules = interactive(backbone_model)
                else:
                    self.modified_modules = interactive(backbone_model, port=interactive_modify)
                self.common_structure = False
                self.exclude_modules = self.default_exclude_modules
            else:
                self.modified_modules = self.default_modified_modules
                self.common_structure = True
                self.exclude_modules = self.default_exclude_modules
        else:
            if interactive_modify:
                raise ValueError("Using modified_modules (or exclude_modules) together with interactive_modify is not supported")
            if modified_modules is not None:
                self.modified_modules = modified_modules
            else:
                self.modified_modules = self.default_modified_modules
            if exclude_modules is not None:
                self.exclude_modules = exclude_modules
            else:
                self.exclude_modules = self.default_exclude_modules
            self.common_structure = common_structure
        if self.common_structure:
            self.structure_mapping = CommonStructureMap(self.backbone_model)
        else:
            self.structure_mapping = None
        if unfrozen_modules is None:
            self.unfrozen_modules = self.default_unfrozen_modules
        if self.common_structure and self.structure_mapping is None:
            raise RuntimeError("Using common structure but the structure mapping is None")

    def forward(self, *args, **kwargs) -> RuntimeError:
        r"""
        .. warning::

            Removed method. As the model is a delta model, which should be attached to a backbone model
            and can't forward any data by itself, please use the backbone model's forward function
            after attaching the delta model to the backbone.
        """
        raise RuntimeError("This is a delta model, which should be attached to a backbone model "
                           "and can't forward any data by itself. Please use the backbone model's forward function "
                           "after attaching the delta model to the backbone.")

    @classmethod
    def from_config(cls, config: Union[BaseDeltaConfig, dict], backbone_model: nn.Module, check_hash=True, **kwargs):
        r"""Initialize a delta model from a config object or a dict containing the configs. To temporarily change
        a value in the config, pass it through kwargs. If the config has a backbone model's hash, which means it is
        a finetuned delta model's config, then we compare the hash in the config with the newly calculated one to
        ensure the finetuned delta model was trained on the passed backbone_model. Pass ``check_hash=False`` to
        disable the checking.

        Args:
            config (:obj:`BaseDeltaConfig` or :obj:`dict`): A config object or a dict that contains the necessary
                values to initialize the delta model.
            backbone_model (:obj:`nn.Module`): A pytorch module that will be passed into the delta model as the
                backbone model. Modifications will be made in place in the backbone model.
            check_hash (:obj:`bool`, default to ``True``): Whether to check the hash of the backbone model against
                the config's backbone hash.
            kwargs: Any configurations that are passed to update the config object. #TODO unit test needed.
        """
        supported_keys = get_arg_names(cls.__init__) + get_arg_names(DeltaBase.__init__)
        config_dict = config.to_dict()
        for key in list(config_dict.keys()):
            if key not in supported_keys:
                config_dict.pop(key)
        return cls(backbone_model, **config_dict)
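
    # Example (a minimal sketch; ``LoraModel`` and the config values are illustrative,
    # any DeltaBase subclass with a matching config class is used the same way):
    #
    #   from transformers import AutoModel
    #   from opendelta import LoraModel
    #   backbone = AutoModel.from_pretrained("bert-base-uncased")
    #   config = LoraModel.config_class(modified_modules=["attn"])
    #   delta_model = LoraModel.from_config(config, backbone_model=backbone)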

    def add_all_delta_to_backbone(self,
                                  backbone: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        r"""The main function to add delta models to the backbone model based on the :obj:`modified_modules`.

        Args:
            backbone (:obj:`nn.Module`, *required*): backbone model that the delta models are built upon. The
                modification to the backbone model is in place.
            modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that are subject to
                update. Leaving this argument as :obj:`None` makes the delta model fall back to the default setting,
                which adds the delta modules to the positions experimented with in the paper. In this setting, the
                common structure mapping is loaded to address the corresponding modules.

        Returns:
            :obj:`nn.Module` The modified backbone model.
        """
        self.plm_total_params = sum(p.numel() for p in backbone.parameters())
        # create a new key list to avoid recursion.
        backbone_key_list = [key for key, _ in backbone.named_modules()]
        for key in backbone_key_list:
            if self.find_key(key, modified_modules):
                self.update_module(backbone, key)
        if self._need_pseudo_data:
            self._pseudo_data_to_instantiate(backbone)
        # mark the parameters that are the delta parameters for easily displaying the delta parameters.
        self.mark_as_delta()
        return backbone
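
    # Typical call site (a sketch): a concrete subclass usually invokes this at the
    # end of its own __init__, e.g.
    #
    #   self.add_all_delta_to_backbone(self.backbone_model, self.modified_modules)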

    def _pseudo_data_to_instantiate(self, backbone: Optional[nn.Module] = None):
        if self.structure_mapping is None:
            self._pseudo_data_to_instantiate_module(backbone)
        else:
            for key in self.structure_mapping.matched_pairs:
                if key == "":
                    submodule = backbone
                else:
                    _, _, submodule = self.find_module(backbone, key)
                self._pseudo_data_to_instantiate_module(submodule)

    def mark_as_delta(self, module: nn.Module = None):
        r"""[NODOC] Mark all of :obj:`module`'s parameters as delta parameters by setting an ``_is_delta`` attribute
        on each of them. Generally, it is used after creating the delta modules. Leaving module as :obj:`None` will
        mark all the parameters in the delta model as ``_is_delta``.

        Args:
            module (:obj:`nn.Module`): The module to mark as delta.
        """
        if module is None:
            module = self  # all the parameters in the delta model.
        for p in module.parameters():
            setattr(p, "_is_delta", True)

    def update_module(self, module: nn.Module, key: str):
        r"""Update a module specified by :obj:`key`. The method is reimplemented in each specific delta model.
        """
        raise NotImplementedError

    def freeze_module(self,
                      module: Optional[nn.Module] = None,
                      exclude: Optional[List[str]] = None,
                      set_state_dict: Optional[bool] = True,
                      ):
        r"""Freeze the parameters of the plm. Leave the parameters in exclude untouched.
        The delta modules are filtered by the ``_is_delta`` attribute because they may share parameters with the
        main model (e.g., the bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are frozen.
                If left as :obj:`None`, the function will use self.backbone_model as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Defaults to all the delta parameters.
            set_state_dict (:obj:`bool`, *optional*, default to :obj:`True`): Whether to set the backbone model's
                state dict to contain only the parameters that still need grad.
        """
        if exclude is None:
            exclude = self.unfrozen_modules
        if module is None:
            module = self.backbone_model
        self._freeze_module_recursive(module, exclude, "")  # modify the active state dict that still needs grad
        if set_state_dict:
            self.set_active_state_dict(module)
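
    # Example (a sketch, assuming ``delta_model`` is an attached DeltaBase subclass):
    #
    #   delta_model.freeze_module(exclude=["deltas", "layernorm_embedding"])
    #   # afterwards, backbone.state_dict() only contains parameters that still require grad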

    def _freeze_module_recursive(self,
                                 module: Optional[nn.Module] = None,
                                 exclude: Optional[List[str]] = None,
                                 prefix=""):
        r"""[NODOC] Freeze the parameters of the plm. Leave the parameters in exclude untouched.
        The delta modules are filtered by the ``_is_delta`` attribute because they may share parameters with the
        main model (e.g., the bias term).

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The module of which some parts are frozen.
                If left as :obj:`None`, the function will use self.backbone_model as the module to be frozen.
            exclude (:obj:`List[str]`, *optional*, default to ``["deltas"]``): The parameters that don't need to
                be frozen. Defaults to all the delta parameters.
            prefix (:obj:`str`, *optional*, default to ``""``): A prefix used for the recursive freezing.
                Should not be changed by passing an argument other than ``""``.
        """
        if is_leaf_module(module):
            for n, p in module.named_parameters():
                if self.find_key(".".join([prefix, n]), exclude):
                    continue
                if "deltas" not in exclude or (not (hasattr(p, "_is_delta") and getattr(p, "_is_delta"))):
                    p.requires_grad = False
            return
        else:
            for n, c in module.named_children():
                if self.find_key(".".join([prefix, n]), exclude):  # if found, leave the parameters untouched
                    continue
                else:  # first freeze the non-module params, then go deeper.
                    params = non_module_param(module)
                    for pn, pp in params:  # use distinct loop variables so `n` (the child name) is not shadowed
                        if "deltas" not in exclude or (not (hasattr(pp, "_is_delta") and getattr(pp, "_is_delta"))):
                            pp.requires_grad = False
                    self._freeze_module_recursive(c, exclude=exclude, prefix=".".join([prefix, n]))

    def find_key(self, key: str, target_list: List[str]):
        r"""Check whether any target string is in the key or in the tail of the key.

        Args:
            key (:obj:`str`): The key (name) of a submodule in an ancestor module,
                e.g., model.encoder.layer.0.attention
            target_list (List[Union[:obj:`str`, :obj:`re.Pattern`]]): The target list that we try to match ``key``
                with, e.g., ["attention"]

        Returns:
            :obj:`bool` True if the key matches the target list.
        """
        for x in self.exclude_modules:
            if key.startswith(x):  # the key starts with an excluded prefix
                return False
        virtual_key = None  # initialized here so the non-common-structure path is well defined
        if self.structure_mapping is not None:
            key, virtual_key, in_virtual_order = self.structure_mapping.transform(key, strict=False)
            # `in_virtual_order` is currently unused: if the common structure designates adding an adapter
            # to the FFN, the adapter will be added to all submodules of the FFN.
        if not key:
            return False
        if virtual_key is None:
            return endswith_in(key, target_list)
        else:
            return endswith_in(key, target_list) or endswith_in(virtual_key, target_list)
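
    # Example (a sketch): with target_list = ["attention"],
    # find_key("encoder.layer.0.attention", ["attention"]) returns True, while
    # find_key("encoder.layer.0.attention.self.query", ["attention"]) returns False,
    # because matching is against the tail of the key.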

    def _pseudo_data_to_instantiate_module(self, module: Optional[nn.Module] = None):
        r"""Some delta models require pseudo data to be passed through the model to determine the dimensionality of
        each tensor in the computation graph.

        (1) Models in the Huggingface Transformers library usually have so-called `dummy_inputs`; we make use of them.
        (2) If the model does not have `dummy_inputs`, we try to create them and issue a warning.
        (3) If we encounter an error in (2), we suggest you create them by setting the `dummy_inputs` attribute.

        Args:
            module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
        """
        if module is None:
            module = self.backbone_model
        device = get_device(module)
        _auto_dummy = False
        try:
            dummy_inputs = module.dummy_inputs
            dummy_inputs = move_dict_to_cuda(dummy_inputs, device)
        except AttributeError:
            logger.warning(f"No `dummy_inputs` attribute in {module.__class__.__name__}, automatically creating `dummy_inputs`. This is likely to raise an error. To set dummy_inputs for your model, please use `setattr(backbone_model, 'dummy_inputs', some_dummy_inputs)` before initializing `{self.__class__.__name__}`")
            _auto_dummy = True
        if _auto_dummy:
            _most_simple_input = torch.tensor([[0, 0]]).to(device)
            if "decoder_input_ids" in signature(module.forward).args:
                dummy_inputs = {"input_ids": _most_simple_input, "decoder_input_ids": _most_simple_input}
            else:
                dummy_inputs = {"input_ids": _most_simple_input}
        _auto_dummy_fail = False
        try:
            module(**dummy_inputs)
        except Exception:
            _auto_dummy_fail = True
        if _auto_dummy_fail:
            raise AttributeError(f"\n{self.__class__.__name__} requires pseudo data to be passed through the model to determine the dimensionality of each tensor in the computation graph.\nThe automatically created dummy inputs failed.\nThe `dummy_inputs` can be any data that makes `backbone_model.forward(**dummy_inputs)` succeed; only the form and shape of the `dummy_inputs` matter.\n\tTo set dummy_inputs for your model, please use `setattr(backbone_model, 'dummy_inputs', some_dummy_inputs)` before initializing `{self.__class__.__name__}`")
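
    # Example (a sketch): if automatic dummy-input creation fails for your backbone,
    # provide your own before creating the delta model, e.g.
    #
    #   setattr(backbone_model, "dummy_inputs", {"input_ids": torch.tensor([[0, 0]])})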

    def trainable_parameters_names(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return the names of all the trainable parameters in the (by default,
        backbone) model.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the trainable parameters' names.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if p.requires_grad]

    def frozen_parameters_names(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return the names of all the frozen parameters in the (by default,
        backbone) model.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the frozen parameters' names.

        Returns:
            :obj:`List[str]`
        """
        if module is None:
            module = self.backbone_model
        return [n for n, p in module.named_parameters() if not p.requires_grad]

    def trainable_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to return all the trainable parameters in the (by default, delta) model.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the trainable parameters.

        Returns:
            :obj:`List[nn.Parameter]`
        """
        if module is None:
            module = self
        return [p for n, p in module.named_parameters() if p.requires_grad]

    def num_trainable_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the number of trainable parameters in the model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the number of trainable parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            if param.requires_grad:
                pnum_tot += param.numel()
        return pnum_tot

    def num_total_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the total number of parameters in the model. Often used to
        compute the trainable rate.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the total number of parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self
        pnum_tot = 0
        for param in module.parameters():
            pnum_tot += param.numel()
        return pnum_tot

    def find_module(self, root_module: nn.Module, key: str):
        r"""Find the module using a key and the root module. Return the parent reference, the child name, and the
        child reference.

        Args:
            root_module (:obj:`root_module`): The root_module to find the submodule in.
            key (:obj:`str`): The relative key to the root module.

        Returns:
            (:obj:`nn.Module`, :obj:`str`, :obj:`nn.Module`):
                * A reference to the parent module of the target module, mainly for substituting the target module.
                * The key of the target module relative to its parent module.
                * The target module.
        """
        sub_keys = key.split(".")
        parent_module = root_module
        for sub_key in sub_keys[:-1]:
            parent_module = getattr(parent_module, sub_key)
        module = getattr(parent_module, sub_keys[-1])
        return parent_module, sub_keys[-1], module
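
    # Example (a sketch):
    #
    #   parent, name, child = delta_model.find_module(backbone, "encoder.layer.0.attention")
    #   # parent is backbone.encoder.layer[0], name == "attention", child is the attention module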

    def _register_delta_infos(self, parent_module, _delta_info):
        r"""Register the delta information.
        Automatically increments the suffix for repeated delta_names.
        """
        _delta_infos = getattr(parent_module, "_delta_infos", [])
        if len(_delta_infos) > 0:  # check for duplicated names
            list_of_deltas = [d['delta_name'] for d in _delta_infos]
            cur_name = _delta_info['delta_name']
            if cur_name in list_of_deltas:
                cur_name = cur_name + "_1"
            counter = 1
            while cur_name in list_of_deltas:
                counter += 1
                cur_name = cur_name.split("_")[0] + "_" + str(counter)
            _delta_info["delta_name"] = cur_name
        _delta_infos.append(_delta_info)
        setattr(parent_module, "_delta_infos", _delta_infos)

    def replace_module(self,
                       parent_module: nn.Module,
                       child_name: str,
                       child_module: nn.Module,
                       new_module: nn.Module,
                       delta_name: Optional[str] = "delta",
                       ):
        r"""Replace a module's child module with the new_module (a delta module). Used by delta methods based on
        direct replacement, such as :class:`opendelta.delta_modules.lora.LoraModel`.

        Args:
            parent_module (:obj:`nn.Module`): The parent module of the replacement.
            child_name (:obj:`str`): The child module's name, i.e., parent_module.child_name gives us child_module.
            child_module (:obj:`nn.Module`): The original child module.
            new_module (:obj:`nn.Module`): The delta module.
            delta_name (:obj:`str`, *optional*, default to ``delta``): The name of the delta module, used for
                recording. parent_module.delta_name will NOT give you the delta module.
        """
        self.delta_modules.append(new_module)
        setattr(parent_module, child_name, new_module)
        # register delta info
        _delta_info = {"method": "replace",
                       "delta_module": new_module,
                       "child_name": child_name,
                       "org_module": child_module,
                       "delta_name": delta_name,
                       "delta_belong": self,
                       "state": "on"}
        self._register_delta_infos(parent_module=parent_module,
                                   _delta_info=_delta_info,
                                   )
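
    # Example (a sketch of how a replacement-style subclass might use it; the
    # `new_module_like` helper is illustrative):
    #
    #   parent, name, child = self.find_module(backbone, key)
    #   new_module = self.new_module_like(child)
    #   self.replace_module(parent, name, child, new_module, delta_name="lora")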

    def modify_module(self, module: nn.Module):
        r"""Modify the inside parameters of a module. This method will be reimplemented in different
        derived classes if needed.
        """
        raise NotImplementedError

    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        r"""Insert a module (not previously existing in the code base) before/after a module. Specifically, it
        modifies the forward function of the original module to first pass the arguments into the new module's
        forward function and then pass them into the original one. The new module can also be inserted after the
        original module with a similar mechanism.

        When implementing the new module, researchers should be aware of the components of the arguments of the
        original module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module into which a delta module is inserted.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            delta_name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already-modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in attach(), to reattach a delta module to the backbone. The
                info of the original delta is passed through ``_delta_info``.
        """
        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
            args = args[1:]  # the first argument here is ``self``
            delta_module = getattr(org_module, delta_name)
            if hasattr(delta_module, "pre_forward"):
                args, kwargs = delta_module.pre_forward(*args, **kwargs)
            ret = _org_func(*args, **kwargs)
            if hasattr(delta_module, "post_forward"):
                ret = delta_module.post_forward(ret)
            return ret

        if strict:
            if hasattr(module.forward, "__wrapped__"):
                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")

        # record info for plug/unplug and nested wrapping
        if _delta_info is None:
            if delta_module is None:
                raise RuntimeError("delta module can't be None; it is needed to ensure successful replication of the parent module.")
            _delta_info = {"method": "insert_sequential",
                           "delta_module": delta_module,
                           "delta_name": delta_name,
                           "delta_belong": self,
                           "state": "on"}
            self._register_delta_infos(parent_module=module,
                                       _delta_info=_delta_info)
        else:
            delta_module = _delta_info["delta_module"]
            delta_name = _delta_info["delta_name"]

        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True)  # decorator.decorate helps preserve the function's metadata (signature, etc.).
        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) registers a function as an object's method
        # For DataParallel's copy behavior. Experimental:
        # may have bugs when module.forward is wrapped in a nested fashion.
        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
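
    # Example (a sketch, assuming an adapter-style delta module that implements
    # `post_forward`; the key is illustrative):
    #
    #   _, _, ffn = self.find_module(backbone, "encoder.layer.0.output")
    #   self.insert_sequential_module(ffn, delta_module=my_adapter, delta_name="adapter")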

    def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
        """Insert a module (not previously existing in the code base) in parallel with a module. Specifically, it
        modifies the forward function of the original module to also pass the arguments into the delta module's
        forward function and set aside its calculation result, which is then combined with the output of the
        backbone module.

        When implementing the new module, researchers should be aware of the arguments and keywords of the original
        module's forward function.

        Args:
            module: (:obj:`nn.Module`): The (sub)module into which a delta module is inserted.
            delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
            delta_name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
            strict: (:obj:`bool`, *optional*): Whether to prohibit modifying an already-modified module.
            _delta_info (:obj:`Dict`, *optional*): Used in attach(), to reattach a delta module to the backbone. The
                info of the original delta is passed through ``_delta_info``.
        """
        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
            args = args[1:]  # the first argument here is ``self``
            delta_module = getattr(org_module, delta_name)
            ret_1 = _org_func(*args, **kwargs)
            ret_2 = delta_module.forward(*args, **kwargs)
            return ret_1 + ret_2

        if strict:
            if hasattr(module.forward, "__wrapped__"):
                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")

        # record info for plug/unplug and nested wrapping
        if _delta_info is None:
            if delta_module is None:
                raise RuntimeError("delta module can't be None; it is needed to ensure successful replication of the parent module.")
            _delta_info = {"method": "insert_parallel",
                           "delta_module": delta_module,
                           "delta_name": delta_name,
                           "delta_belong": self,
                           "state": "on"}
            self._register_delta_infos(parent_module=module,
                                       _delta_info=_delta_info)
        else:
            delta_module = _delta_info["delta_module"]
            delta_name = _delta_info["delta_name"]

        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True)  # decorator.decorate helps preserve the function's metadata (signature, etc.).
        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) registers a function as an object's method
        # For DataParallel's copy behavior. Experimental:
        # may have bugs when module.forward is wrapped in a nested fashion.
        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
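
    # Example (a sketch, assuming a delta module whose output is summed with the
    # backbone module's output, e.g., a parallel adapter; the key is illustrative):
    #
    #   _, _, attn = self.find_module(backbone, "encoder.layer.0.attention")
    #   self.insert_parallel_module(attn, delta_module=my_parallel_adapter, delta_name="parallel_adapter")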

    def set_active_state_dict(self, module: nn.Module):
        r"""Modify the state_dict function of the model (by default, the backbone model) to return only the tunable
        part.

        Args:
            module (:obj:`nn.Module`): The module to modify. The modification is in place.
        """
        def _caller(_org_func, includes, *args, **kwargs):
            state_dict = _org_func(*args, **kwargs)
            keys = list(state_dict.keys())
            for n in keys:
                if n not in includes:
                    state_dict.pop(n)
            return state_dict
        includes = self.trainable_parameters_names(module)  # using excludes instead would cause trouble when the model has shared weights
        if hasattr(module.state_dict, "__wrapped__"):
            raise RuntimeWarning("The state_dict function might have been wrapped by a decorator, is it intended? Did you freeze the parameters twice?")
        module.state_dict = decorate(module.state_dict, _caller, extras=(includes,), kwsyntax=True)  # decorator.decorate helps preserve the function's metadata (signature, etc.).
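
    # Example (a sketch): after ``freeze_module(..., set_state_dict=True)``,
    #
    #   torch.save(backbone.state_dict(), "delta_ckpt.pt")
    #
    # saves only the parameters that still require grad (the delta/unfrozen part).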

    def _load_state_dict_into_backbone(self, backbone_model: nn.Module = None, state_dict: dict = {}):
        r"""[NODOC]
        """
        if backbone_model is None:
            backbone_model = self.backbone_model
        backbone_model.load_state_dict(state_dict, strict=False)

    def create_config_from_model(self):
        r"""[NODOC] If the delta model was built by directly passing arguments instead of a config object,
        create the config of the delta model for saving the delta model.
        """
        # common attributes
        config = self.config_class()
        config_keys = signature(config.__init__)[0] + signature(super(self.config_class, config).__init__)[0]
        for key in config_keys:
            val = getattr(self, key) if hasattr(self, key) else None
            setattr(config, key, val)
        config.delta_type = self.delta_type
        self.config = config

    def log(self, module=None, delta_ratio=True, trainable_ratio=True, visualization=True, cuda_memory=True):
        r"""Log and visualize the result of applying delta.
        Possible options are ``trainable_ratio``, ``visualization``, ``delta_ratio``.

        Args:
            delta_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of parameters in the delta modules.
            trainable_ratio (:obj:`bool`, *optional*): Whether to compute the ratio of trainable parameters.
            visualization (:obj:`bool`, *optional*): Whether to visualize the parameter information of the modified
                backbone.
        """
        if module is None:
            module = self.backbone_model
        if visualization:
            from opendelta import Visualization
            Visualization(module).structure_graph()
        self.get_statistics(module)
        if trainable_ratio:
            logger.info("Trainable Ratio: {:.2f}%".format(self.stat['trainable_ratio'] * 100))
        if delta_ratio:
            logger.info("Delta Parameter Ratio: {:.2f}%".format(self.stat['delta_ratio'] * 100))
        if cuda_memory:
            logger.info("Static Memory {:.2f} GB, Max Memory {:.2f} GB".format(self.stat['cudamem'], self.stat['maxcudamem']))

    def get_statistics(self, module=None):
        r"""Compute the statistics of the parameters in the delta modules and store them in ``self.stat``.

        Args:
            module (:obj:`nn.Module`, *optional*): The module to compute the statistics for.
        """
        if module is None:
            module = self.backbone_model
        self.stat = {}
        n_trainable = self.num_trainable_parameters(module)
        n_total = self.num_total_parameters(module)
        self.stat['trainable_ratio'] = n_trainable / n_total
        n_delta = self.num_delta_parameters(module)
        self.stat['delta_ratio'] = n_delta / n_total
        cudamem = 0
        maxcudamem = 0
        for device_id in range(torch.cuda.device_count()):
            cudamem += torch.cuda.memory_allocated(f"cuda:{device_id}") / 1024 ** 3
            maxcudamem += torch.cuda.max_memory_allocated(f"cuda:{device_id}") / 1024 ** 3
        self.stat['cudamem'] = cudamem
        self.stat['maxcudamem'] = maxcudamem
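
    # Example (a sketch):
    #
    #   delta_model.log()
    #   # prints the structure graph, then logs e.g.
    #   # "Trainable Ratio: 0.81%", "Delta Parameter Ratio: 0.81%", and CUDA memory statistics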

    def num_delta_parameters(self, module: Optional[nn.Module] = None):
        r"""[NODOC] A small sugar function to get the number of delta parameters in the backbone model. Often used to
        compute the delta rate.

        Args:
            module (:obj:`nn.Module`): the module of which we want to know the number of delta parameters.

        Returns:
            :obj:`int`
        """
        if module is None:
            module = self.backbone_model
        pnum_tot = 0
        for param in module.parameters():
            if hasattr(param, "_is_delta"):
                pnum_tot += param.numel()
        return pnum_tot

    # Two functions to plug and remove the delta model.
    def attach(self, module: Optional[nn.Module] = None, reset_state_dict=True):
        r"""Reattach the delta modules to the backbone. Note that this method can not be used to create new delta
        modules. Instead, a :meth:`DeltaBase.detach` should precede this method.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                reattach the deltas to.
        """
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "on":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['delta_module'])
                    elif _delta_info['method'] == "insert_sequential":
                        self.insert_sequential_module(module=submodule,
                                                      _delta_info=_delta_info)
                    elif _delta_info['method'] == "insert_parallel":
                        self.insert_parallel_module(module=submodule,
                                                    _delta_info=_delta_info)
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "on"
        if reset_state_dict:
            self.set_active_state_dict(module)

    def detach(self, module: Optional[nn.Module] = None, reset_state_dict=True):
        r"""Detach the delta modules from the backbone. The delta modules are not deleted, but temporarily turned
        off. Use :meth:`DeltaBase.attach` to reattach the delta model to the backbone.

        Args:
            module (:obj:`object`, *optional*, default to :obj:`None`): The backbone module that we
                detach the deltas from.
        """
        if module is None:
            module = self.backbone_model
        for name, submodule in module.named_modules():
            if hasattr(submodule, "_delta_infos"):
                _delta_infos = getattr(submodule, "_delta_infos")
                for _delta_info in _delta_infos:
                    if _delta_info['delta_belong'] is not self:
                        continue
                    if _delta_info["state"] == "off":
                        continue
                    if _delta_info['method'] == "replace":
                        setattr(submodule, _delta_info["child_name"], _delta_info['org_module'])
                    elif _delta_info['method'] == "insert_sequential":
                        if hasattr(submodule.forward, "__wrapped__"):
                            submodule.forward = submodule.forward.__wrapped__
                            delattr(submodule, _delta_info["delta_name"])
                        else:
                            raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It is not a wrapped function.".format(name))
                    elif _delta_info['method'] == "insert_parallel":
                        if hasattr(submodule.forward, "__wrapped__"):
                            submodule.forward = submodule.forward.__wrapped__
                            delattr(submodule, _delta_info["delta_name"])
                        else:
                            raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It is not a wrapped function.".format(name))
                    else:
                        raise NotImplementedError
                    _delta_info['state'] = "off"
        if reset_state_dict:
            try:
                module.state_dict = module.state_dict.__wrapped__
            except AttributeError:
                pass
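
    # Example (a sketch): temporarily turn the deltas off for a clean-backbone
    # forward pass, then turn them back on:
    #
    #   delta_model.detach()
    #   clean_output = backbone(**dummy_inputs)
    #   delta_model.attach()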