small update in save_and_load

This commit is contained in:
shengdinghu 2022-07-06 14:00:58 +00:00
parent ed2bd8c50f
commit b9a0f7cf89
12 changed files with 106 additions and 53 deletions

Binary file not shown.

Binary file not shown.

BIN
dist/opendelta-0.2.1-py3-none-any.whl vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.1.tar.gz vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.2-py3-none-any.whl vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.2.2.tar.gz vendored Normal file

Binary file not shown.

View File

@ -1,5 +1,5 @@
__version__ = "0.2.0"
__version__ = "0.2.2"
class GlobalSetting:
def __init__(self):

View File

@ -193,6 +193,7 @@ class BaseDeltaConfig:
config_dict.pop(config_key)
unused_config_keys.append(config_key)
logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")
config = cls(**config_dict)

View File

@ -22,10 +22,18 @@ class InterFaceMixin:
self._reverse_axis_order = np.argsort(self._axis_order).tolist()
def _transpose(self, tensor):
if tensor.dim() == 3:
return tensor.permute(*self._axis_order)
else:
return tensor
def _reverse_transpose(self, tensor):
if tensor.dim() == 3:
return tensor.permute(*self._reverse_axis_order).contiguous()
else:
return tensor
def _convert_data_type(self, tensor):
self._data_type_record = tensor.dtype
@ -37,6 +45,8 @@ class InterFaceMixin:
class AdapterLayer(nn.Module, InterFaceMixin):
r"""A layer of adapter tuning module.
"""

View File

@ -91,10 +91,10 @@ class SaveLoadMixin:
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
push_to_dc: bool = True,
center_args: Optional[Union[DeltaCenterArguments, dict]] = None,
center_args_pool: Optional[dict] = None,
list_tags: Optional[List] = None,
dict_tags: Optional[Dict] = None,
center_args: Optional[Union[DeltaCenterArguments, dict]] = dict(),
center_args_pool: Optional[dict] = dict(),
list_tags: Optional[List] = list(),
dict_tags: Optional[Dict] = dict(),
delay_push: bool = False,
):
r"""
@ -106,10 +106,11 @@ class SaveLoadMixin:
If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
which is a subdirectory of the current working directory.
save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
model file.
model file. if ``False``, only the state dict will be saved.
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
state_dict is loaded from the backbone model's trainable parameters.
save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
state_dict_only: (optional) if ``True``, only the state_dict will be saved.
push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter.
This includes:
- creating a configuration file for the model
@ -131,6 +132,7 @@ class SaveLoadMixin:
self.create_config_from_model()
self.add_configs_when_saving()
if push_to_dc:
final_center_args = self.create_delta_center_args(center_args=center_args,
center_args_pool=center_args_pool)
@ -140,6 +142,8 @@ class SaveLoadMixin:
return
os.makedirs(save_directory, exist_ok=True)
if push_to_dc:
save_directory = os.path.join(save_directory, final_center_args.name)
os.makedirs(save_directory, exist_ok=True)
@ -149,13 +153,13 @@ class SaveLoadMixin:
if state_dict is None:
state_dict = model_to_save.state_dict()
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
# Save the config
if save_config:
self.config.save_finetuned(save_directory)
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
save_function(state_dict, output_model_file)
logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}"
)
@ -164,24 +168,36 @@ class SaveLoadMixin:
logger.info("Creating yaml file for delta center")
self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
if not delay_push:
OssClient.upload(base_dir=save_directory)
else:
logger.info("Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
logger.info(f"Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
+"*"*30)
# get absolute path of saved_directory,
def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
def create_yml(self, save_dir, config, list_tags=list(), dict_tags=dict()):
f = open("{}/config.yml".format(save_dir), 'w')
config_dict = vars(config)
config_dict['dict_tags'] = dict_tags if dict_tags is not None else {}
config_dict['list_tags'] = list_tags if list_tags is not None else []
config_dict['dict_tags'] = dict_tags
config_dict['list_tags'] = list_tags
yaml.safe_dump(config_dict, f)
f.close()
def load_checkpoint(self, path, load_func=torch.load, backbone_model=None):
r"""Simple method for loading only the checkpoint
"""
if backbone_model is None:
backbone_model = self.backbone_model
self.backbone_model.load_state_dict(load_func(f"{path}/{WEIGHTS_NAME}"), strict=False)
def save_checkpoint(self, path, save_func=torch.save, backbone_model=None):
r"""Simple method for saving only the checkpoint"""
if backbone_model is None:
backbone_model = self.backbone_model
save_func(backbone_model.state_dict(), f"{path}/{WEIGHTS_NAME}")
@classmethod
def from_finetuned(cls,
finetuned_delta_path: Optional[Union[str, os.PathLike]],

View File

@ -199,6 +199,12 @@ distilbert_mapping = {
}
}
MAPPINGERROR_MSG = "We haven't provide common structure mapping for this backbone model." + \
" If it is a common enough PLM, please check whether it is wrapped by other wrapper model, e.g., XXXForSequenceClassification." +\
"Please manually add the "+\
"delta models by speicifying 'modified_modules' based on the visualization of model structure. Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail."
def transform(org_key, mapping, strict=True, warning=False, verbose=False):
chain = org_key.split(".")
@ -266,7 +272,7 @@ def mapping_for_ConditionalGeneration(mapping, type):
if type == "t5":
mapping["lm_head"] = {"__name__":"lm_head.proj"}
else:
raise NotImplementedError
raise NotImplementedError(MAPPINGERROR_MSG.format())
return mapping
class _LazyLoading(OrderedDict):
@ -276,7 +282,7 @@ class _LazyLoading(OrderedDict):
def __getitem__(self, key):
if key not in self._mapping_string:
raise KeyError(key)
raise KeyError(MAPPINGERROR_MSG)
value = self._mapping_string[key]
self._mapping[key] = eval(value)
return self._mapping[key]
@ -289,6 +295,7 @@ class _LazyLoading(OrderedDict):
return item in self._mapping_string
class CommonStructureMap(object):
r""" A lazy loading structure map.
"""
@ -296,7 +303,6 @@ class CommonStructureMap(object):
"RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""",
"RobertaForMaskedLM": "roberta_mapping",
"BertForMaskedLM": "bert_mapping",
"BertForSequenceClassification": """mapping_for_SequenceClassification(bert_mapping, "bert")""",
"T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""",
"DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")"""
})
@ -315,8 +321,17 @@ class CommonStructureMap(object):
"""
backbone_class = type(backbone_model).__name__
if backbone_class not in cls.Mappings:
raise KeyError(backbone_class)
raise KeyError(MAPPINGERROR_MSG)
try:
mapping = cls.Mappings[backbone_class]
except KeyError:
logger.error(MAPPINGERROR_MSG)
exit(-1)
if visualize:
logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")
vis = Visualization(backbone_model)
@ -346,4 +361,3 @@ if __name__ == "__main__":
for name, _ in plm.named_modules():
transform(name, t5_mapping, strict=True, warning=False)

View File

@ -3,22 +3,34 @@ import setuptools
import os
import os
def get_requirements(path):
ret = []
# path = "/home/ubuntu/OpenDelta_beta/OpenDelta/"
with open(os.path.join(path, "opendelta.egg-info/requires.txt"), encoding="utf-8") as freq:
for line in freq.readlines():
ret.append( line.strip() )
requires = """torch>=1.8.0
transformers>=4.10.0
datasets==1.17.0
sentencepiece>=0.1.96
tqdm>=4.62.2
# loralib
decorator
rich
web.py
gitpython
delta_center_client
"""
def get_requirements():
ret = [x for x in requires.split("\n") if len(x)>0]
print("requirements:", ret)
return ret
path = os.path.dirname(os.path.abspath(__file__))
requires = get_requirements(path)
# path = os.path.dirname(os.path.abspath(__file__))
# requires = get_requirements(path)
with open('README.md', 'r') as f:
setuptools.setup(
name = 'opendelta',
version = "0.2.0",
version = "0.2.2",
description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@ -28,7 +40,7 @@ with open('README.md', 'r') as f:
url="https://github.com/thunlp/OpenDelta",
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
python_requires=">=3.6.0",
install_requires=get_requirements(path),
install_requires=get_requirements(),
package_dir={'opendelta':'opendelta'},
package_data= {
'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'],