small update in save_and_load
This commit is contained in:
parent
ed2bd8c50f
commit
b9a0f7cf89
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -1,5 +1,5 @@
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
__version__ = "0.2.2"
|
||||||
|
|
||||||
class GlobalSetting:
|
class GlobalSetting:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
|
@ -193,6 +193,7 @@ class BaseDeltaConfig:
|
||||||
config_dict.pop(config_key)
|
config_dict.pop(config_key)
|
||||||
unused_config_keys.append(config_key)
|
unused_config_keys.append(config_key)
|
||||||
logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")
|
logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")
|
||||||
|
|
||||||
config = cls(**config_dict)
|
config = cls(**config_dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,10 +22,18 @@ class InterFaceMixin:
|
||||||
self._reverse_axis_order = np.argsort(self._axis_order).tolist()
|
self._reverse_axis_order = np.argsort(self._axis_order).tolist()
|
||||||
|
|
||||||
def _transpose(self, tensor):
|
def _transpose(self, tensor):
|
||||||
|
if tensor.dim() == 3:
|
||||||
return tensor.permute(*self._axis_order)
|
return tensor.permute(*self._axis_order)
|
||||||
|
else:
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _reverse_transpose(self, tensor):
|
def _reverse_transpose(self, tensor):
|
||||||
|
if tensor.dim() == 3:
|
||||||
return tensor.permute(*self._reverse_axis_order).contiguous()
|
return tensor.permute(*self._reverse_axis_order).contiguous()
|
||||||
|
else:
|
||||||
|
return tensor
|
||||||
|
|
||||||
def _convert_data_type(self, tensor):
|
def _convert_data_type(self, tensor):
|
||||||
self._data_type_record = tensor.dtype
|
self._data_type_record = tensor.dtype
|
||||||
|
@ -37,6 +45,8 @@ class InterFaceMixin:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AdapterLayer(nn.Module, InterFaceMixin):
|
class AdapterLayer(nn.Module, InterFaceMixin):
|
||||||
r"""A layer of adapter tuning module.
|
r"""A layer of adapter tuning module.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -91,10 +91,10 @@ class SaveLoadMixin:
|
||||||
state_dict: Optional[dict] = None,
|
state_dict: Optional[dict] = None,
|
||||||
save_function: Callable = torch.save,
|
save_function: Callable = torch.save,
|
||||||
push_to_dc: bool = True,
|
push_to_dc: bool = True,
|
||||||
center_args: Optional[Union[DeltaCenterArguments, dict]] = None,
|
center_args: Optional[Union[DeltaCenterArguments, dict]] = dict(),
|
||||||
center_args_pool: Optional[dict] = None,
|
center_args_pool: Optional[dict] = dict(),
|
||||||
list_tags: Optional[List] = None,
|
list_tags: Optional[List] = list(),
|
||||||
dict_tags: Optional[Dict] = None,
|
dict_tags: Optional[Dict] = dict(),
|
||||||
delay_push: bool = False,
|
delay_push: bool = False,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
|
@ -106,10 +106,11 @@ class SaveLoadMixin:
|
||||||
If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
|
If not specified, the model will be saved in the directory ``./delta_checkpoints/``,
|
||||||
which is a subdirectory of the current working directory.
|
which is a subdirectory of the current working directory.
|
||||||
save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
|
save_config: (optional) if ``True``, the configuration file will be saved in the same directory as the
|
||||||
model file.
|
model file. if ``False``, only the state dict will be saved.
|
||||||
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
|
state_dict: (optional) a dictionary containing the model's state_dict. If not specified, the
|
||||||
state_dict is loaded from the backbone model's trainable parameters.
|
state_dict is loaded from the backbone model's trainable parameters.
|
||||||
save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
|
save_function: (optional) the function used to save the model. Defaults to ``torch.save``.
|
||||||
|
state_dict_only: (optional) if ``True``, only the state_dict will be saved.
|
||||||
push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter.
|
push_to_dc: (optional) if ``True``, the model will prepare things to pushed to the DeltaCenter.
|
||||||
This includes:
|
This includes:
|
||||||
- creating a configuration file for the model
|
- creating a configuration file for the model
|
||||||
|
@ -131,6 +132,7 @@ class SaveLoadMixin:
|
||||||
self.create_config_from_model()
|
self.create_config_from_model()
|
||||||
self.add_configs_when_saving()
|
self.add_configs_when_saving()
|
||||||
|
|
||||||
|
if push_to_dc:
|
||||||
final_center_args = self.create_delta_center_args(center_args=center_args,
|
final_center_args = self.create_delta_center_args(center_args=center_args,
|
||||||
center_args_pool=center_args_pool)
|
center_args_pool=center_args_pool)
|
||||||
|
|
||||||
|
@ -140,6 +142,8 @@ class SaveLoadMixin:
|
||||||
return
|
return
|
||||||
|
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
if push_to_dc:
|
||||||
save_directory = os.path.join(save_directory, final_center_args.name)
|
save_directory = os.path.join(save_directory, final_center_args.name)
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
@ -149,13 +153,13 @@ class SaveLoadMixin:
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = model_to_save.state_dict()
|
state_dict = model_to_save.state_dict()
|
||||||
|
|
||||||
|
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
||||||
|
save_function(state_dict, output_model_file)
|
||||||
|
|
||||||
# Save the config
|
# Save the config
|
||||||
if save_config:
|
if save_config:
|
||||||
self.config.save_finetuned(save_directory)
|
self.config.save_finetuned(save_directory)
|
||||||
|
|
||||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
|
||||||
save_function(state_dict, output_model_file)
|
|
||||||
|
|
||||||
|
|
||||||
logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}"
|
logger.info("\n"+"*"*30+f"\nYou delta models has been saved locally to:\n\t{os.path.abspath(save_directory)}"
|
||||||
)
|
)
|
||||||
|
@ -164,24 +168,36 @@ class SaveLoadMixin:
|
||||||
logger.info("Creating yaml file for delta center")
|
logger.info("Creating yaml file for delta center")
|
||||||
self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
|
self.create_yml(save_directory, final_center_args, list_tags, dict_tags)
|
||||||
|
|
||||||
|
|
||||||
if not delay_push:
|
if not delay_push:
|
||||||
OssClient.upload(base_dir=save_directory)
|
OssClient.upload(base_dir=save_directory)
|
||||||
else:
|
else:
|
||||||
logger.info("Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
|
logger.info(f"Delay push: you can push it to the delta center later using \n\tpython -m DeltaCenter upload {os.path.abspath(save_directory)}\n"
|
||||||
+"*"*30)
|
+"*"*30)
|
||||||
|
|
||||||
# get absolute path of saved_directory,
|
|
||||||
|
|
||||||
|
|
||||||
def create_yml(self, save_dir, config, list_tags=None, dict_tags=None):
|
|
||||||
|
def create_yml(self, save_dir, config, list_tags=list(), dict_tags=dict()):
|
||||||
f = open("{}/config.yml".format(save_dir), 'w')
|
f = open("{}/config.yml".format(save_dir), 'w')
|
||||||
config_dict = vars(config)
|
config_dict = vars(config)
|
||||||
config_dict['dict_tags'] = dict_tags if dict_tags is not None else {}
|
config_dict['dict_tags'] = dict_tags
|
||||||
config_dict['list_tags'] = list_tags if list_tags is not None else []
|
config_dict['list_tags'] = list_tags
|
||||||
yaml.safe_dump(config_dict, f)
|
yaml.safe_dump(config_dict, f)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
def load_checkpoint(self, path, load_func=torch.load, backbone_model=None):
|
||||||
|
r"""Simple method for loading only the checkpoint
|
||||||
|
"""
|
||||||
|
if backbone_model is None:
|
||||||
|
backbone_model = self.backbone_model
|
||||||
|
self.backbone_model.load_state_dict(load_func(f"{path}/{WEIGHTS_NAME}"), strict=False)
|
||||||
|
|
||||||
|
def save_checkpoint(self, path, save_func=torch.save, backbone_model=None):
|
||||||
|
r"""Simple method for saving only the checkpoint"""
|
||||||
|
if backbone_model is None:
|
||||||
|
backbone_model = self.backbone_model
|
||||||
|
save_func(backbone_model.state_dict(), f"{path}/{WEIGHTS_NAME}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_finetuned(cls,
|
def from_finetuned(cls,
|
||||||
finetuned_delta_path: Optional[Union[str, os.PathLike]],
|
finetuned_delta_path: Optional[Union[str, os.PathLike]],
|
||||||
|
|
|
@ -199,6 +199,12 @@ distilbert_mapping = {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MAPPINGERROR_MSG = "We haven't provide common structure mapping for this backbone model." + \
|
||||||
|
" If it is a common enough PLM, please check whether it is wrapped by other wrapper model, e.g., XXXForSequenceClassification." +\
|
||||||
|
"Please manually add the "+\
|
||||||
|
"delta models by speicifying 'modified_modules' based on the visualization of model structure. Refer to `https://opendelta.readthedocs.io/en/latest/notes/faq.html` for detail."
|
||||||
|
|
||||||
def transform(org_key, mapping, strict=True, warning=False, verbose=False):
|
def transform(org_key, mapping, strict=True, warning=False, verbose=False):
|
||||||
|
|
||||||
chain = org_key.split(".")
|
chain = org_key.split(".")
|
||||||
|
@ -266,7 +272,7 @@ def mapping_for_ConditionalGeneration(mapping, type):
|
||||||
if type == "t5":
|
if type == "t5":
|
||||||
mapping["lm_head"] = {"__name__":"lm_head.proj"}
|
mapping["lm_head"] = {"__name__":"lm_head.proj"}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError(MAPPINGERROR_MSG.format())
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
class _LazyLoading(OrderedDict):
|
class _LazyLoading(OrderedDict):
|
||||||
|
@ -276,7 +282,7 @@ class _LazyLoading(OrderedDict):
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if key not in self._mapping_string:
|
if key not in self._mapping_string:
|
||||||
raise KeyError(key)
|
raise KeyError(MAPPINGERROR_MSG)
|
||||||
value = self._mapping_string[key]
|
value = self._mapping_string[key]
|
||||||
self._mapping[key] = eval(value)
|
self._mapping[key] = eval(value)
|
||||||
return self._mapping[key]
|
return self._mapping[key]
|
||||||
|
@ -289,6 +295,7 @@ class _LazyLoading(OrderedDict):
|
||||||
return item in self._mapping_string
|
return item in self._mapping_string
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CommonStructureMap(object):
|
class CommonStructureMap(object):
|
||||||
r""" A lazy loading structure map.
|
r""" A lazy loading structure map.
|
||||||
"""
|
"""
|
||||||
|
@ -296,7 +303,6 @@ class CommonStructureMap(object):
|
||||||
"RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""",
|
"RobertaForSequenceClassification": """mapping_for_SequenceClassification(roberta_mapping, "roberta")""",
|
||||||
"RobertaForMaskedLM": "roberta_mapping",
|
"RobertaForMaskedLM": "roberta_mapping",
|
||||||
"BertForMaskedLM": "bert_mapping",
|
"BertForMaskedLM": "bert_mapping",
|
||||||
"BertForSequenceClassification": """mapping_for_SequenceClassification(bert_mapping, "bert")""",
|
|
||||||
"T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""",
|
"T5ForConditionalGeneration": """mapping_for_ConditionalGeneration(t5_mapping, "t5")""",
|
||||||
"DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")"""
|
"DebertaV2ForSequenceClassification": """mapping_for_SequenceClassification(debertav2_mapping, "deberta")"""
|
||||||
})
|
})
|
||||||
|
@ -315,8 +321,17 @@ class CommonStructureMap(object):
|
||||||
"""
|
"""
|
||||||
backbone_class = type(backbone_model).__name__
|
backbone_class = type(backbone_model).__name__
|
||||||
if backbone_class not in cls.Mappings:
|
if backbone_class not in cls.Mappings:
|
||||||
raise KeyError(backbone_class)
|
raise KeyError(MAPPINGERROR_MSG)
|
||||||
|
|
||||||
|
try:
|
||||||
mapping = cls.Mappings[backbone_class]
|
mapping = cls.Mappings[backbone_class]
|
||||||
|
except KeyError:
|
||||||
|
logger.error(MAPPINGERROR_MSG)
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if visualize:
|
if visualize:
|
||||||
logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")
|
logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")
|
||||||
vis = Visualization(backbone_model)
|
vis = Visualization(backbone_model)
|
||||||
|
@ -346,4 +361,3 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
for name, _ in plm.named_modules():
|
for name, _ in plm.named_modules():
|
||||||
transform(name, t5_mapping, strict=True, warning=False)
|
transform(name, t5_mapping, strict=True, warning=False)
|
||||||
|
|
32
setup.py
32
setup.py
|
@ -3,22 +3,34 @@ import setuptools
|
||||||
import os
|
import os
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def get_requirements(path):
|
|
||||||
ret = []
|
requires = """torch>=1.8.0
|
||||||
# path = "/home/ubuntu/OpenDelta_beta/OpenDelta/"
|
transformers>=4.10.0
|
||||||
with open(os.path.join(path, "opendelta.egg-info/requires.txt"), encoding="utf-8") as freq:
|
datasets==1.17.0
|
||||||
for line in freq.readlines():
|
sentencepiece>=0.1.96
|
||||||
ret.append( line.strip() )
|
tqdm>=4.62.2
|
||||||
|
# loralib
|
||||||
|
decorator
|
||||||
|
rich
|
||||||
|
web.py
|
||||||
|
gitpython
|
||||||
|
delta_center_client
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_requirements():
|
||||||
|
ret = [x for x in requires.split("\n") if len(x)>0]
|
||||||
|
print("requirements:", ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
path = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
requires = get_requirements(path)
|
# path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
# requires = get_requirements(path)
|
||||||
|
|
||||||
with open('README.md', 'r') as f:
|
with open('README.md', 'r') as f:
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name = 'opendelta',
|
name = 'opendelta',
|
||||||
version = "0.2.0",
|
version = "0.2.2",
|
||||||
description = "An open source framework for delta learning (parameter efficient learning).",
|
description = "An open source framework for delta learning (parameter efficient learning).",
|
||||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
@ -28,7 +40,7 @@ with open('README.md', 'r') as f:
|
||||||
url="https://github.com/thunlp/OpenDelta",
|
url="https://github.com/thunlp/OpenDelta",
|
||||||
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
|
keywords = ['PLM', 'Parameter-efficient-Learning', 'AI', 'NLP'],
|
||||||
python_requires=">=3.6.0",
|
python_requires=">=3.6.0",
|
||||||
install_requires=get_requirements(path),
|
install_requires=get_requirements(),
|
||||||
package_dir={'opendelta':'opendelta'},
|
package_dir={'opendelta':'opendelta'},
|
||||||
package_data= {
|
package_data= {
|
||||||
'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'],
|
'opendelta':["utils/interactive/templates/*.html", 'requirments.txt'],
|
||||||
|
|
Loading…
Reference in New Issue