small update in save_and_load
This commit is contained in:
parent
ed2bd8c50f
commit
b9a0f7cf89
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,5 +1,5 @@
|
|||
|
||||
__version__ = "0.2.0"
|
||||
__version__ = "0.2.2"
|
||||
|
||||
class GlobalSetting:
|
||||
def __init__(self):
|
||||
|
|
|
@ -193,6 +193,7 @@ class BaseDeltaConfig:
|
|||
config_dict.pop(config_key)
|
||||
unused_config_keys.append(config_key)
|
||||
logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
|
||||
|
|
|
@ -22,10 +22,18 @@ class InterFaceMixin:
|
|||
self._reverse_axis_order = np.argsort(self._axis_order).tolist()
|
||||
|
||||
def _transpose(self, tensor):
|
||||
return tensor.permute(*self._axis_order)
|
||||
if tensor.dim() == 3:
|
||||
return tensor.permute(*self._axis_order)
|
||||
else:
|
||||
return tensor
|
||||
|
||||
|
||||
|
||||
def _reverse_transpose(self, tensor):
|
||||
return tensor.permute(*self._reverse_axis_order).contiguous()
|
||||
if tensor.dim() == 3:
|
||||
return tensor.permute(*self._reverse_axis_order).contiguous()
|
||||
else:
|
||||
return tensor
|
||||
|
||||
def _convert_data_type(self, tensor):
|
||||
self._data_type_record = tensor.dtype
|
||||
|
@ -37,6 +45,8 @@ class InterFaceMixin:
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
class AdapterLayer(nn.Module, InterFaceMixin):
|
||||
r"""A layer of adapter tuning module.
|
||||
"""
|
||||
|
|
|
@ -91,10 +91,10 @@ class SaveLoadMixin:
|
|||
state_dict: Optional[dict] = None,
|
||||
save_function: Callable = torch.save,
|
||||
push_to_dc: bool = True,
|
||||
center_args: Optional[Union[DeltaCenterArguments, dict]] = None,
|
||||
center_args_pool: Optional[dict] = None,
|
||||
list_tags: Optional[List] = None,
|
||||
dict_tags: Optional[Dict] = None,
|
||||
center_args: Optional[Union[DeltaCenterArguments, dict]] = dict(),
|
||||
center_args_pool: Optional[dict] = dict(),
|
||||
list_tags: Optional[List] = list(),
|
||||
dict_tags: Optional[Dict] = dict(),
|
||||
delay_push: bool = False,
|
||||
):
|
||||
r"""
|
||||
|
@ -106,10 +106,11 @@ class SaveLoadMixin:
|
|||
If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
|
||||
which is a subdirectory of the current working directory.
|
||||
save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
|
||||
model file.
|
||||
model file. if ``False``, only the state dict will be saved.
|
||||
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
|
||||
state_dict is loaded from the backbone model's trainable parameters.
|
||||
save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
|
||||
state_dict_only: (optional) if ``True``, only the state_dict will be saved.
|
||||
push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter.
|
||||
This includes:
|
||||
- creating a configuration file for the model
|
||||
|
@ -131,7 +132,8 @@ class SaveLoadMixin:
|
|||
self.create_config_from_model()
|
||||
self.add_configs_when_saving()
|
||||
|
||||
final_center_args = self.create_delta_center_args(center_args=center_args,
|
||||
if push_to_dc:
|
||||
final_center_args = self.create_delta_center_args(center_args=center_args,
|
||||
center_args_pool=center_args_pool)
|
||||
|
||||
save_directory = finetuned_delta_path
|
||||
|
@ -140,8 +142,10 @@ class SaveLoadMixin:
|
|||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
save_directory = os.path.join(save_directory, final_center_args.name)
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_dc:
|
||||
save_directory = os.path.join(save_directory, final_center_args.name)
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = self.backbone_model# unwrap_model(self)
|
||||
|
||||
|
@ -149,13 +153,13 @@ class SaveLoadMixin:
|
|||
if state_dict is None:
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Save the config
|
||||
if save_config:
|
||||
self.config.save_finetuned(save_directory)
|
||||
|
||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||
save_function(state_dict, output_model_file)
|
||||
|
||||
# Save the config
|
||||
if save_config:
|
||||
self.config.save_finetuned(save_directory)
|
||||
|
||||
|
||||
logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}"
|
||||
)
|
||||
|
@ -164,24 +168,36 @@ class SaveLoadMixin:
|
|||
logger.info("Creating yaml file for delta center")
|
||||
self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
|
||||
|
||||
|
||||
if not delay_push:
|
||||
OssClient.upload(base_dir=save_directory)
|
||||
else:
|
||||
logger.info("Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
|
||||
+"*"*30)
|
||||
|
||||
# get absolute path of saved_directory,
|
||||
if not delay_push:
|
||||
OssClient.upload(base_dir=save_directory)
|
||||
else:
|
||||
logger.info(f"Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
|
||||
+"*"*30)
|
||||
|
||||
|
||||
def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
|
||||
|
||||
|
||||
def create_yml(self, save_dir, config, list_tags=list(), dict_tags=dict()):
|
||||
f = open("{}/config.yml".format(save_dir), 'w')
|
||||
config_dict = vars(config)
|
||||
config_dict['dict_tags'] = dict_tags if dict_tags is not None else {}
|
||||
config_dict['list_tags'] = list_tags if list_tags is not None else []
|
||||
config_dict['dict_tags'] = dict_tags
|
||||
config_dict['list_tags'] = list_tags
|
||||
yaml.safe_dump(config_dict, f)
|
||||
f.close()
|
||||
|
||||
def load_checkpoint(self, path, load_func=torch.load, backbone_model=None):
    r"""Simple method for loading only the checkpoint.

    Args:
        path: directory that contains the ``WEIGHTS_NAME`` weights file.
        load_func: callable that takes the weights file path and returns a
            state dict. Defaults to ``torch.load``.
        backbone_model: (optional) the model that receives the weights.
            Defaults to ``self.backbone_model``.
    """
    if backbone_model is None:
        backbone_model = self.backbone_model
    # Load into the resolved model. Previously the weights always went into
    # ``self.backbone_model``, so an explicitly passed model was ignored.
    # ``strict=False`` is kept from the original call.
    backbone_model.load_state_dict(load_func(f"{path}/{WEIGHTS_NAME}"), strict=False)
|
||||
|
||||
def save_checkpoint(self, path, save_func=torch.save, backbone_model=None):
    r"""Write only the state dict of a backbone model into ``path``.

    Args:
        path: directory in which the ``WEIGHTS_NAME`` file is written.
        save_func: callable invoked as ``save_func(state_dict, file_path)``.
            Defaults to ``torch.save``.
        backbone_model: (optional) the model whose state dict is saved.
            Defaults to ``self.backbone_model``.
    """
    model = self.backbone_model if backbone_model is None else backbone_model
    weights = model.state_dict()
    save_func(weights, f"{path}/{WEIGHTS_NAME}")
|
||||
|
||||
@classmethod
|
||||
def from_finetuned(cls,
|
||||
finetuned_delta_path: Optional[Union[str, os.PathLike]],
|
||||
|
|
|
@ -24,7 +24,7 @@ t5_mapping = {
|
|||
}
|
||||
}
|
||||
},
|
||||
"final_layer_norm": {"__name__":"layer_norm"},
|
||||
"final_layer_norm": {"__name__":"layer_norm"},
|
||||
},
|
||||
"decoder": {"__name__":"decoder",
|
||||
"embed_tokens": {"__name__":"embeddings"},
|
||||
|
@ -199,8 +199,14 @@ distilbert_mapping = {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
# Error text used when no common structure mapping is available for the
# backbone model's class. Kept free of ``{}`` placeholders so that calling
# ``.format()`` on it returns the text unchanged.
MAPPINGERROR_MSG = (
    "We haven't provided common structure mapping for this backbone model."
    " If it is a common enough PLM, please check whether it is wrapped by other wrapper model, e.g., XXXForSequenceClassification."
    " Please manually add the "
    "delta models by specifying 'modified_modules' based on the visualization of model structure."
    " Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail."
)
|
||||
|
||||
def transform(org_key, mapping, strict=True, warning=False, verbose=False):
|
||||
|
||||
|
||||
chain = org_key.split(".")
|
||||
query = ""
|
||||
node = mapping
|
||||
|
@ -215,7 +221,7 @@ def transform(org_key, mapping, strict=True, warning=False, verbose=False):
|
|||
if strict:
|
||||
if warning:
|
||||
print(f"'{org_key}' has no common mapping.")
|
||||
return
|
||||
return
|
||||
else:
|
||||
new_chain.append(query)
|
||||
else:
|
||||
|
@ -226,19 +232,19 @@ def transform(org_key, mapping, strict=True, warning=False, verbose=False):
|
|||
new_chain.append(query)
|
||||
query = ""
|
||||
else:
|
||||
query += "."
|
||||
query += "."
|
||||
if query!="":
|
||||
if strict:
|
||||
if warning:
|
||||
print("A part of the orginial key hasn't been matched!")
|
||||
return
|
||||
return
|
||||
else:
|
||||
new_chain.append(query.strip(".")) # tailing query
|
||||
new_key = ".".join(new_chain)
|
||||
if verbose:
|
||||
print(f"{org_key} => {new_key}")
|
||||
return new_key
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -255,7 +261,7 @@ def mapping_for_SequenceClassification(mapping, type):
|
|||
mapping["classifier"] = {"__name__": "classifier"}
|
||||
elif type == "deberta":
|
||||
mapping.pop("lm_predictions.lm_head")
|
||||
mapping["pooler"] = {"__name__": "classifier"}
|
||||
mapping["pooler"] = {"__name__": "classifier"}
|
||||
mapping["classifier"] = {"__name__": "classifier"}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -266,29 +272,30 @@ def mapping_for_ConditionalGeneration(mapping, type):
|
|||
if type == "t5":
|
||||
mapping["lm_head"] = {"__name__":"lm_head.proj"}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(MAPPINGERROR_MSG.format())
|
||||
return mapping
|
||||
|
||||
class _LazyLoading(OrderedDict):
|
||||
def __init__(self, mapping):
|
||||
self._mapping_string = mapping
|
||||
self._mapping = {}
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self._mapping_string:
|
||||
raise KeyError(key)
|
||||
raise KeyError(MAPPINGERROR_MSG)
|
||||
value = self._mapping_string[key]
|
||||
self._mapping[key] = eval(value)
|
||||
return self._mapping[key]
|
||||
|
||||
return self._mapping[key]
|
||||
|
||||
def keys(self):
|
||||
return list(self._mapping_string.keys())
|
||||
|
||||
|
||||
def __contains__(self, item):
|
||||
|
||||
return item in self._mapping_string
|
||||
|
||||
|
||||
|
||||
class CommonStructureMap(object):
|
||||
r""" A lazy loading structure map.
|
||||
"""
|
||||
|
@ -296,7 +303,6 @@ class CommonStructureMap(object):
|
|||
"RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""",
|
||||
"RobertaForMaskedLM": "roberta_mapping",
|
||||
"BertForMaskedLM": "bert_mapping",
|
||||
"BertForSequenceClassification": """mapping_for_SequenceClassification(bert_mapping, "bert")""",
|
||||
"T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""",
|
||||
"DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")"""
|
||||
})
|
||||
|
@ -315,8 +321,17 @@ class CommonStructureMap(object):
|
|||
"""
|
||||
backbone_class = type(backbone_model).__name__
|
||||
if backbone_class not in cls.Mappings:
|
||||
raise KeyError(backbone_class)
|
||||
mapping = cls.Mappings[backbone_class]
|
||||
raise KeyError(MAPPINGERROR_MSG)
|
||||
|
||||
try:
|
||||
mapping = cls.Mappings[backbone_class]
|
||||
except KeyError:
|
||||
logger.error(MAPPINGERROR_MSG)
|
||||
exit(-1)
|
||||
|
||||
|
||||
|
||||
|
||||
if visualize:
|
||||
logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")
|
||||
vis = Visualization(backbone_model)
|
||||
|
@ -346,4 +361,3 @@ if __name__ == "__main__":
|
|||
|
||||
for name, _ in plm.named_modules():
|
||||
transform(name, t5_mapping, strict=True, warning=False)
|
||||
|
32
setup.py
32
setup.py
|
@ -3,22 +3,34 @@ import setuptools
|
|||
import os
|
||||
import os
|
||||
|
||||
def get_requirements(path):
|
||||
ret = []
|
||||
# path = "/home/ubuntu/OpenDelta_beta/OpenDelta/"
|
||||
with open(os.path.join(path, "opendelta.egg-info/requires.txt"), encoding="utf-8") as freq:
|
||||
for line in freq.readlines():
|
||||
ret.append( line.strip() )
|
||||
|
||||
requires = """torch>=1.8.0
|
||||
transformers>=4.10.0
|
||||
datasets==1.17.0
|
||||
sentencepiece>=0.1.96
|
||||
tqdm>=4.62.2
|
||||
# loralib
|
||||
decorator
|
||||
rich
|
||||
web.py
|
||||
gitpython
|
||||
delta_center_client
|
||||
"""
|
||||
|
||||
def get_requirements():
    """Return the install requirements listed in the module-level ``requires`` string.

    Every line of ``requires`` holds one requirement specifier. Surrounding
    whitespace is stripped, and blank lines and ``#`` comment lines (such as
    the disabled ``# loralib`` entry) are dropped, so only real specifiers
    are returned.

    Returns:
        list of requirement specifier strings, in their original order.
    """
    stripped = (line.strip() for line in requires.splitlines())
    # A commented-out entry is not a valid specifier, so it must not be returned.
    ret = [spec for spec in stripped if spec and not spec.startswith("#")]
    print("requirements:", ret)
    return ret
|
||||
|
||||
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
requires = get_requirements(path)
|
||||
|
||||
# path = os.path.dirname(os.path.abspath(__file__))
|
||||
# requires = get_requirements(path)
|
||||
|
||||
with open('README.md', 'r') as f:
|
||||
setuptools.setup(
|
||||
name = 'opendelta',
|
||||
version = "0.2.0",
|
||||
version = "0.2.2",
|
||||
description = "An open source framework for delta learning (parameter efficient learning).",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
@ -28,7 +40,7 @@ with open('README.md', 'r') as f:
|
|||
url="https://github.com/thunlp/OpenDelta",
|
||||
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
|
||||
python_requires=">=3.6.0",
|
||||
install_requires=get_requirements(path),
|
||||
install_requires=get_requirements(),
|
||||
package_dir={'opendelta':'opendelta'},
|
||||
package_data= {
|
||||
'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'],
|
||||
|
|
Loading…
Reference in New Issue