small update in save_and_load

This commit is contained in:
shengdinghu 2022-07-06 14:00:58 +00:00
parent ed2bd8c50f
commit b9a0f7cf89
12 changed files with 106 additions and 53 deletions

Binary file not shown.

Binary file not shown.

BIN
dist/opendelta-0.2.1-py3-none-any.whl vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.1.tar.gz vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.2-py3-none-any.whl vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.2.tar.gz vendored Normal file

Binary file not shown.

View File

@ -1,5 +1,5 @@
__version__ = "0.2.0" __version__ = "0.2.2"
class GlobalSetting: class GlobalSetting:
def __init__(self): def __init__(self):

View File

@ -193,6 +193,7 @@ class BaseDeltaConfig:
config_dict.pop(config_key) config_dict.pop(config_key)
unused_config_keys.append(config_key) unused_config_keys.append(config_key)
logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}") logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")
config = cls(**config_dict) config = cls(**config_dict)

View File

@ -22,10 +22,18 @@ class InterFaceMixin:
self._reverse_axis_order = np.argsort(self._axis_order).tolist() self._reverse_axis_order = np.argsort(self._axis_order).tolist()
def _transpose(self, tensor): def _transpose(self, tensor):
if tensor.dim() == 3:
return tensor.permute(*self._axis_order) return tensor.permute(*self._axis_order)
else:
return tensor
def _reverse_transpose(self, tensor): def _reverse_transpose(self, tensor):
if tensor.dim() == 3:
return tensor.permute(*self._reverse_axis_order).contiguous() return tensor.permute(*self._reverse_axis_order).contiguous()
else:
return tensor
def _convert_data_type(self, tensor): def _convert_data_type(self, tensor):
self._data_type_record = tensor.dtype self._data_type_record = tensor.dtype
@ -37,6 +45,8 @@ class InterFaceMixin:
class AdapterLayer(nn.Module, InterFaceMixin): class AdapterLayer(nn.Module, InterFaceMixin):
r"""A layer of adapter tuning module. r"""A layer of adapter tuning module.
""" """

View File

@ -91,10 +91,10 @@ class SaveLoadMixin:
state_dict: Optional[dict] = None, state_dict: Optional[dict] = None,
save_function: Callable = torch.save, save_function: Callable = torch.save,
push_to_dc: bool = True, push_to_dc: bool = True,
center_args: Optional[Union[DeltaCenterArguments, dict]] = None, center_args: Optional[Union[DeltaCenterArguments, dict]] = dict(),
center_args_pool: Optional[dict] = None, center_args_pool: Optional[dict] = dict(),
list_tags: Optional[List] = None, list_tags: Optional[List] = list(),
dict_tags: Optional[Dict] = None, dict_tags: Optional[Dict] = dict(),
delay_push: bool = False, delay_push: bool = False,
): ):
r""" r"""
@ -106,10 +106,11 @@ class SaveLoadMixin:
If not specified, the model will be saved in the directory ``./delta_checkpoints/``, If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
which is a subdirectory of the current working directory. which is a subdirectory of the current working directory.
save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
model file. model file. if ``False``, only the state dict will be saved.
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
state_dict is loaded from the backbone model's trainable parameters. state_dict is loaded from the backbone model's trainable parameters.
save_function: (optional) the function used to save the model. Defaults to ``torch.save``. save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
state_dict_only: (optional) if ``True``, only the state_dict will be saved.
push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter. push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter.
This includes: This includes:
- creating a configuration file for the model - creating a configuration file for the model
@ -131,6 +132,7 @@ class SaveLoadMixin:
self.create_config_from_model() self.create_config_from_model()
self.add_configs_when_saving() self.add_configs_when_saving()
if push_to_dc:
final_center_args = self.create_delta_center_args(center_args=center_args, final_center_args = self.create_delta_center_args(center_args=center_args,
center_args_pool=center_args_pool) center_args_pool=center_args_pool)
@ -140,6 +142,8 @@ class SaveLoadMixin:
return return
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
if push_to_dc:
save_directory = os.path.join(save_directory, final_center_args.name) save_directory = os.path.join(save_directory, final_center_args.name)
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
@ -149,13 +153,13 @@ class SaveLoadMixin:
if state_dict is None: if state_dict is None:
state_dict = model_to_save.state_dict() state_dict = model_to_save.state_dict()
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
# Save the config # Save the config
if save_config: if save_config:
self.config.save_finetuned(save_directory) self.config.save_finetuned(save_directory)
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}" logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}"
) )
@ -164,24 +168,36 @@ class SaveLoadMixin:
logger.info("Creating yaml file for delta center") logger.info("Creating yaml file for delta center")
self.create_yml(save_directory, final_center_args, list_tags, dict_tags) self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
if not delay_push: if not delay_push:
OssClient.upload(base_dir=save_directory) OssClient.upload(base_dir=save_directory)
else: else:
logger.info("Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n" logger.info(f"Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
+"*"*30) +"*"*30)
# get absolute path of saved_directory,
def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
def create_yml(self, save_dir, config, list_tags=list(), dict_tags=dict()):
f = open("{}/config.yml".format(save_dir), 'w') f = open("{}/config.yml".format(save_dir), 'w')
config_dict = vars(config) config_dict = vars(config)
config_dict['dict_tags'] = dict_tags if dict_tags is not None else {} config_dict['dict_tags'] = dict_tags
config_dict['list_tags'] = list_tags if list_tags is not None else [] config_dict['list_tags'] = list_tags
yaml.safe_dump(config_dict, f) yaml.safe_dump(config_dict, f)
f.close() f.close()
def load_checkpoint(self, path, load_func=torch.load, backbone_model=None):
    r"""Load only the checkpoint weights into a backbone model.

    Args:
        path: directory that contains the ``WEIGHTS_NAME`` file.
        load_func: callable that receives the checkpoint file path and returns
            a state dict. Defaults to ``torch.load``.
        backbone_model: (optional) the model that receives the weights. If not
            specified, ``self.backbone_model`` is used.
    """
    if backbone_model is None:
        backbone_model = self.backbone_model
    # Load into the resolved model so that an explicitly passed
    # ``backbone_model`` is honoured instead of always using self.backbone_model.
    backbone_model.load_state_dict(load_func(f"{path}/{WEIGHTS_NAME}"), strict=False)
def save_checkpoint(self, path, save_func=torch.save, backbone_model=None):
    r"""Save only the state dict of a backbone model.

    Args:
        path: directory in which the ``WEIGHTS_NAME`` file is written.
        save_func: callable invoked as ``save_func(state_dict, file_path)``.
            Defaults to ``torch.save``.
        backbone_model: (optional) the model whose state dict is saved. If not
            specified, ``self.backbone_model`` is used.
    """
    target_model = self.backbone_model if backbone_model is None else backbone_model
    weights = target_model.state_dict()
    save_func(weights, f"{path}/{WEIGHTS_NAME}")
@classmethod @classmethod
def from_finetuned(cls, def from_finetuned(cls,
finetuned_delta_path: Optional[Union[str, os.PathLike]], finetuned_delta_path: Optional[Union[str, os.PathLike]],

View File

@ -199,6 +199,12 @@ distilbert_mapping = {
} }
} }
# Shared message for the errors reported when no common structure mapping is
# available for a backbone model. Keep it free of curly braces: one call site
# passes it through ``str.format()`` with no arguments.
MAPPINGERROR_MSG = (
    "We haven't provided common structure mapping for this backbone model."
    " If it is a common enough PLM, please check whether it is wrapped by other wrapper model, e.g., XXXForSequenceClassification."
    " Please manually add the "
    "delta models by specifying 'modified_modules' based on the visualization of model structure."
    " Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail."
)
def transform(org_key, mapping, strict=True, warning=False, verbose=False): def transform(org_key, mapping, strict=True, warning=False, verbose=False):
chain = org_key.split(".") chain = org_key.split(".")
@ -266,7 +272,7 @@ def mapping_for_ConditionalGeneration(mapping, type):
if type == "t5": if type == "t5":
mapping["lm_head"] = {"__name__":"lm_head.proj"} mapping["lm_head"] = {"__name__":"lm_head.proj"}
else: else:
raise NotImplementedError raise NotImplementedError(MAPPINGERROR_MSG.format())
return mapping return mapping
class _LazyLoading(OrderedDict): class _LazyLoading(OrderedDict):
@ -276,7 +282,7 @@ class _LazyLoading(OrderedDict):
def __getitem__(self, key): def __getitem__(self, key):
if key not in self._mapping_string: if key not in self._mapping_string:
raise KeyError(key) raise KeyError(MAPPINGERROR_MSG)
value = self._mapping_string[key] value = self._mapping_string[key]
self._mapping[key] = eval(value) self._mapping[key] = eval(value)
return self._mapping[key] return self._mapping[key]
@ -289,6 +295,7 @@ class _LazyLoading(OrderedDict):
return item in self._mapping_string return item in self._mapping_string
class CommonStructureMap(object): class CommonStructureMap(object):
r""" A lazy loading structure map. r""" A lazy loading structure map.
""" """
@ -296,7 +303,6 @@ class CommonStructureMap(object):
"RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""", "RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""",
"RobertaForMaskedLM": "roberta_mapping", "RobertaForMaskedLM": "roberta_mapping",
"BertForMaskedLM": "bert_mapping", "BertForMaskedLM": "bert_mapping",
"BertForSequenceClassification": """mapping_for_SequenceClassification(bert_mapping, "bert")""",
"T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""", "T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""",
"DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")""" "DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")"""
}) })
@ -315,8 +321,17 @@ class CommonStructureMap(object):
""" """
backbone_class = type(backbone_model).__name__ backbone_class = type(backbone_model).__name__
if backbone_class not in cls.Mappings: if backbone_class not in cls.Mappings:
raise KeyError(backbone_class) raise KeyError(MAPPINGERROR_MSG)
try:
mapping = cls.Mappings[backbone_class] mapping = cls.Mappings[backbone_class]
except KeyError:
logger.error(MAPPINGERROR_MSG)
exit(-1)
if visualize: if visualize:
logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.") logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")
vis = Visualization(backbone_model) vis = Visualization(backbone_model)
@ -346,4 +361,3 @@ if __name__ == "__main__":
for name, _ in plm.named_modules(): for name, _ in plm.named_modules():
transform(name, t5_mapping, strict=True, warning=False) transform(name, t5_mapping, strict=True, warning=False)

View File

@ -3,22 +3,34 @@ import setuptools
import os import os
import os import os
def get_requirements(path):
ret = [] requires = """torch>=1.8.0
# path = "/home/ubuntu/OpenDelta_beta/OpenDelta/" transformers>=4.10.0
with open(os.path.join(path, "opendelta.egg-info/requires.txt"), encoding="utf-8") as freq: datasets==1.17.0
for line in freq.readlines(): sentencepiece>=0.1.96
ret.append( line.strip() ) tqdm>=4.62.2
# loralib
decorator
rich
web.py
gitpython
delta_center_client
"""
def get_requirements():
ret = [x for x in requires.split("\n") if len(x)>0]
print("requirements:", ret)
return ret return ret
path = os.path.dirname(os.path.abspath(__file__))
requires = get_requirements(path) # path = os.path.dirname(os.path.abspath(__file__))
# requires = get_requirements(path)
with open('README.md', 'r') as f: with open('README.md', 'r') as f:
setuptools.setup( setuptools.setup(
name = 'opendelta', name = 'opendelta',
version = "0.2.0", version = "0.2.2",
description = "An open source framework for delta learning (parameter efficient learning).", description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
@ -28,7 +40,7 @@ with open('README.md', 'r') as f:
url="https://github.com/thunlp/OpenDelta", url="https://github.com/thunlp/OpenDelta",
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'], keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
python_requires=">=3.6.0", python_requires=">=3.6.0",
install_requires=get_requirements(path), install_requires=get_requirements(),
package_dir={'opendelta':'opendelta'}, package_dir={'opendelta':'opendelta'},
package_data= { package_data= {
'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'], 'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'],