from copy import deepcopy
from typing import Any, Dict, OrderedDict

from bigmodelvis import Visualization
import torch.nn as nn
from opendelta.utils.logging import get_logger
import importlib
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.basemodel import DeltaBase

logger = get_logger(__name__)

DELTA_CONFIG_MAPPING = {
    "lora": "LoraConfig",
    "low_rank_adapter": "LowRankAdapterConfig",
    "bitfit": "BitFitConfig",
    "adapter": "AdapterConfig",
    "compacter": "CompacterConfig",
    "prefix": "PrefixConfig",
    "soft_prompt": "SoftPromptConfig",
    "parallel_adapter": "ParallelAdapterConfig",
}

DELTA_MODEL_MAPPING = {
    "lora": "LoraModel",
    "low_rank_adapter": "LowRankAdapterModel",
    "bitfit": "BitFitModel",
    "adapter": "AdapterModel",
    "compacter": "CompacterModel",
    "prefix": "PrefixModel",
    "soft_prompt": "SoftPromptModel",
    "parallel_adapter": "ParallelAdapterModel",
}
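
# The two mappings above share their delta-type keys: the Auto classes below resolve a delta type
# string such as "lora" to the config/model classes defined in opendelta/delta_models/<delta_type>.py.
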

class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        module_name = key  # the delta type string doubles as the module name under opendelta.delta_models
        if module_name not in self._modules:
            self._modules[module_name] = importlib.import_module(f".{module_name}", "opendelta.delta_models")
        return getattr(self._modules[module_name], value)

    def keys(self):
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self):
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

    def __iter__(self):
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by an OpenDelta config, pick another name.")
        self._extra_content[key] = value


LAZY_CONFIG_MAPPING = _LazyConfigMapping(DELTA_CONFIG_MAPPING)
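
# Usage sketch (illustrative comments only; `MyDeltaConfig` is a hypothetical user-defined class):
#
#   LoraConfig = LAZY_CONFIG_MAPPING["lora"]                 # first access imports opendelta.delta_models.lora
#   LAZY_CONFIG_MAPPING.register("my_delta", MyDeltaConfig)  # plug in a config for a custom delta type
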

class AutoDeltaConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~AutoDeltaConfig.from_finetuned` or :meth:`~AutoDeltaConfig.from_dict` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self, *args, **kwargs):
        raise AttributeError(
            f"{self.__class__.__name__} is designed to be instantiated using\n"
            f"\t(1) `{self.__class__.__name__}.from_finetuned(finetuned_model_name_or_path)`\n"
            f"or\t(2) `{self.__class__.__name__}.from_dict(config_dict, **kwargs)`"
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        r"""Instantiate a DeltaConfig according to the dict. Automatically load the config specified by
        :obj:`delta_type`.

        Args:
            config_dict (:obj:`dict`): The dict of configs of the delta model.
            kwargs: Other keyword arguments passed to initialize the config.

        Examples:

        .. code-block:: python

            config = AutoDeltaConfig.from_dict({"delta_type": "lora"})               # This will load the default lora config.
            config = AutoDeltaConfig.from_dict({"delta_type": "lora", "lora_r": 5})  # The default lora config, with lora_r = 5.

        """
        config_dict = deepcopy(config_dict)
        delta_type = config_dict.pop("delta_type", None)
        if delta_type is None:
            raise RuntimeError("No `delta_type` was specified in the config dict, cannot determine which config class to load.")
        config_class = LAZY_CONFIG_MAPPING[delta_type]
        return config_class.from_dict(config_dict, **kwargs)

    @classmethod
    def from_finetuned(cls, finetuned_delta_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a finetuned delta model configuration.
        The configuration class to instantiate is selected based on the ``delta_type`` property of the config
        object that is loaded.

        Parameters:

            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*): Can be either:

                - A string, the model id of a finetuned delta model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
                - A path to a *directory* containing a configuration file saved using the :py:meth:`~opendelta.basemodel.DeltaBase.save_finetuned` method, e.g., ``./my_model_directory/``.
                - A path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``.

            cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.

        Examples:

        .. code-block:: python

            from opendelta import AutoDeltaConfig
            delta_config = AutoDeltaConfig.from_finetuned("thunlp/FactQA_T5-large_Adapter")

        """
        config_dict, kwargs = BaseDeltaConfig.get_config_dict(finetuned_delta_path, **kwargs)
        if "delta_type" in config_dict:
            config_class = LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: use pattern matching on the path string.
            for pattern, config_class in LAZY_CONFIG_MAPPING.items():
                if pattern in str(finetuned_delta_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {finetuned_delta_path}. "
            f"Should have a `delta_type` key in the loaded config, or contain one of the following strings "
            f"in its name: {', '.join(LAZY_CONFIG_MAPPING.keys())}"
        )
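
# Note on the fallback in `from_finetuned` above: a repo id such as "DeltaHub/lora_t5-base_mrpc"
# contains the substring "lora", so a checkpoint whose config lacks a `delta_type` field still
# resolves to LoraConfig through the pattern-matching loop.
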
### AutoModels below

class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (model or tokenizer, for instance) that loads keys and
    values lazily when they are accessed.

    Args:
        - config_mapping: The map from delta type to config class name.
        - model_mapping: The map from delta type to model (or tokenizer) class name.
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type not in self._model_mapping:
            raise KeyError(key)
        model_name = self._model_mapping[model_type]
        return self._load_attr_from_module(model_type, model_name)

    def _load_attr_from_module(self, model_type, attr):
        if model_type not in self._modules:
            self._modules[model_type] = importlib.import_module(f".{model_type}", "opendelta.delta_models")
        return getattribute_from_module(self._modules[model_type], attr)

    def keys(self):
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key, default):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self):
        return bool(self.keys())

    def values(self):
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self):
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self):
        return iter(self.keys())

    def __contains__(self, item):
        if item in self._extra_content:
            return True
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key, value):
        """
        Register a new model in this mapping.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping.keys():
                raise ValueError(f"'{key}' is already used by an OpenDelta model.")
        self._extra_content[key] = value


LAZY_DELTA_MAPPING = _LazyAutoMapping(DELTA_CONFIG_MAPPING, DELTA_MODEL_MAPPING)
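
# Usage sketch (illustrative comments only): unlike LAZY_CONFIG_MAPPING, this mapping is keyed by
# *config classes* rather than delta-type strings, e.g.
#
#   model_class = LAZY_DELTA_MAPPING[LoraConfig]  # "LoraConfig" -> "lora" -> LoraModel, imported lazily
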

def get_values(model_mapping):
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)
    return result
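
# For instance, with hypothetical classes A, B and C: get_values({"x": (A, B), "y": C}) == [A, B, C],
# i.e. tuple- or list-valued entries are flattened into a single list.
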

def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> object of another model type. In that case we try
    # to grab the object at the top level of transformers, guarding against infinite recursion when the
    # attribute is missing there as well.
    transformers_module = importlib.import_module("transformers")
    if module is transformers_module:
        raise ValueError(f"Could not find {attr} in {transformers_module}!")
    return getattribute_from_module(transformers_module, attr)
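
# Behavior sketch: _LazyAutoMapping._load_attr_from_module("lora", "LoraModel") imports
# opendelta.delta_models.lora and returns its `LoraModel` attribute; a name that only exists at the
# top level of `transformers` is found through the fallback above.
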

class AutoDeltaModel:
    r"""
    This is a generic delta model class that will be instantiated as one of the delta model classes of the
    library when created with the :meth:`~AutoDeltaModel.from_finetuned` or :meth:`~AutoDeltaModel.from_config`
    class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """
    _delta_model_mapping = LAZY_DELTA_MAPPING

    def __init__(self, *args, **kwargs):
        raise AttributeError(
            f"{self.__class__.__name__} is designed to be instantiated using\n"
            f"\t(1) `{self.__class__.__name__}.from_finetuned(finetuned_delta_path, backbone_model, *model_args, **kwargs)`\n"
            f"or\t(2) `{self.__class__.__name__}.from_config(delta_config, backbone_model, **kwargs)`"
        )

    @classmethod
    def from_config(cls, config, backbone_model, **kwargs) -> DeltaBase:
        r"""Automatically instantiate a delta model based on the :obj:`config`. The delta model corresponding
        to the :obj:`config` will be loaded and initialized using the arguments in :obj:`config`.

        .. note::
            Using :meth:`from_config` alone will not load the finetuned weight file (e.g., pytorch_model.bin).
            Please use :meth:`from_finetuned` directly for that.

        Args:
            config (:obj:`BaseDeltaConfig`): The config of the delta model.
            backbone_model (:obj:`nn.Module`): The backbone model to be modified.

        Examples:

        .. code-block:: python

            config = AutoDeltaConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")
            delta_model = AutoDeltaModel.from_config(config, backbone_model)

        """
        if type(config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(config)]
            return model_class.from_config(config, backbone_model, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._delta_model_mapping.keys())}."
        )

    @classmethod
    def from_finetuned(cls, finetuned_delta_path, backbone_model, *model_args, **kwargs) -> DeltaBase:
        r"""Automatically instantiate a delta model and load the finetuned checkpoints based on the
        :obj:`finetuned_delta_path`, which can either be a string pointing to a local path or a url pointing
        to the delta hub. It will check the hash after loading the delta model to see whether the correct
        backbone and delta checkpoint are used.

        Args:

            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*): Can be either:

                - A string, the model name of a finetuned delta model configuration hosted inside a model repo on `Delta Center <https://www.openbmb.org/toolKits/deltacenter>`_, like ``thunlp/FactQA_T5-large_Adapter``.
                - A path to a directory containing a configuration file saved using the :meth:`~opendelta.utils.saving_loading_utils.SaveLoadMixin.save_finetuned` method, e.g., ``./my_model_directory/``.
                - A path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``. The last two options are not tested but inherited from huggingface.

            backbone_model (:obj:`nn.Module`): The backbone model to be modified.
            model_args: Other arguments used to initialize the model. See :meth:`DeltaBase.from_finetuned` for details.
            kwargs: Other keyword arguments that will be passed into :meth:`DeltaBase.from_finetuned`. See :meth:`DeltaBase.from_finetuned` for details.

        Example:

        .. code-block:: python

            # `t5` is assumed to be a T5-large backbone loaded beforehand, e.g. via transformers
            delta_model = AutoDeltaModel.from_finetuned("thunlp/FactQA_T5-large_Adapter", backbone_model=t5)

        """
        delta_config = kwargs.pop("delta_config", None)

        if not isinstance(delta_config, BaseDeltaConfig):
            delta_config, kwargs = AutoDeltaConfig.from_finetuned(
                finetuned_delta_path, return_unused_kwargs=True, **kwargs
            )

        if type(delta_config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(delta_config)]
            return model_class.from_finetuned(finetuned_delta_path, backbone_model, *model_args, delta_config=delta_config, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {delta_config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._delta_model_mapping.keys())}."
        )


if __name__ == "__main__":
    config = AutoDeltaConfig.from_dict({"delta_type": "lora", "lora_r": 7})

    from transformers import AutoModelForSequenceClassification
    model = AutoModelForSequenceClassification.from_pretrained("../../plm_cache/roberta-base/", num_labels=2)

    delta_model = AutoDeltaModel.from_config(config, model)
    delta_model.freeze_module(exclude=['deltas', 'classifier'], set_state_dict=True)
    # delta_model.save_finetuned("autodelta_try", push_to_hub=True, private=True)
    delta_model = AutoDeltaModel.from_finetuned("ShengdingHu/autodelta_try", model, use_auth_token=True)
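    # Optional sanity check (a sketch): visualize which modules the delta model injected into the
    # backbone, using the `Visualization` helper imported at the top of this file.
    # Visualization(model).structure_graph()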