# opendelta/delta_configs.py
# (mirror header removed: the raw scrape artifacts "387 lines", "18 KiB",
# "Raw Normal View History" were not valid Python)
import os
import re
from typing import Union, Dict, Any, Tuple, Optional
from opendelta import __version__ as opendelta_version
from opendelta.utils import logging
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
import transformers
import json
import copy
# Canonical file name under which a delta configuration is saved/loaded.
CONFIG_NAME = "config.json"
# Version of the installed transformers package, recorded into every serialized config.
transformers_version = transformers.__version__
# Keys that are always kept by to_diff_dict(), even if equal to the defaults,
# so that saved configs always carry the package versions they were made with.
checked_package_versions = ["transformers_version", "opendelta_version"]
logger = logging.get_logger(__name__)
# Alias kept for parity with transformers' configuration_utils naming.
FULL_CONFIGURATION_FILE = "config.json"
# Matches variant configuration files such as "config.4bit.json".
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class BaseDeltaConfig:
    r"""Base class for all delta configuration classes.

    Handles a few parameters common to all delta models' configurations, as well as
    methods for loading/downloading/saving configurations.

    Class attributes (overridden by derived classes):

    - **delta_type** (:obj:`str`) -- the name of the delta modules, used to create the
      correct :py:class:`~opendelta.AutoConfig`.

    Args:
        modified_modules (:obj:`List[str]`, *optional*, defaults to :obj:`None`):
            The list of keys used to determine which modules you want to modify. OpenDelta
            will take every module that **ends with** one of the provided keys as the
            modification target. When not given any value, i.e. ``modified_modules=None``,
            the delta module will use its corresponding default modification modules.

            .. note::

                **Examples**: When adding delta to DistilBertModel,

                1. set to ``["0.attention.out_lin"]`` will add delta modules to the attention
                   output of distilbert's layer 0, i.e.
                   ``distilbert.transformer.layer.0.attention.out_lin``.
                2. set to ``["attention.out_lin"]`` will add the delta modules in every
                   layer's ``attention.out_lin``.
        exclude_modules (:obj:`List[str]`, *optional*, defaults to :obj:`None`):
            The modules that start with these strings will be excluded from modification.
            Note that currently only plain text (no regular expression) is supported.
        unfrozen_modules (:obj:`List[str]`, *optional*, defaults to :obj:`["deltas"]`):
            The modules that are unfrozen during training. Including the ones that are newly
            introduced as delta modules, and the ones that are originally a part of the model
            but set to trainable (:obj:`requires_grad=True`) to train together with the
            delta modules. OpenDelta will take every module that **ends with** one of the
            provided keys and all its sub-modules and parameters as trainable.

            .. note::

                **Examples**: When adding delta to DistilBertModel,

                1. set this argument to ``["bias"]`` will make all bias terms tunable.
                2. set this argument to ``["attention"]`` will make all parameters in all
                   attention modules tunable.
                3. set this argument to ``["deltas"]`` will make all the parameters in the
                   newly introduced delta modules tunable.
                4. set this argument to ``["classifier"]`` will make all parameters in the
                   classifier tunable.
                5. set this argument to ``["3.ffn.lin2", "deltas", "classifier"]`` will make
                   all parameters in the third layer's feed forward layer's second linear
                   layer, the delta modules, and the classifier modules tunable.
        common_structure (:obj:`bool`, *optional*, defaults to :obj:`False`):
            Whether to use the common structure mapping of the transformer model when
            designating :obj:`modified_modules` and :obj:`unfrozen_modules`.
        backbone_class (:obj:`str`, *optional*, defaults to :obj:`None`):
            The name of the backbone model's class, e.g. ``RobertaForMaskedLM``. Saving this
            information lets the users explicitly know on which backbone the delta model is
            trained.
        backbone_checkpoint_name (:obj:`str`, *optional*, defaults to :obj:`None`):
            The specific checkpoint of the model. In the ideal case, it should be the url to
            download the checkpoint. However, we do not force the user to specify a
            downloadable url here.
        backbone_hash (:obj:`str`, *optional*, defaults to :obj:`None`):
            The md5-hash of the backbone model. It is calculated using the string
            representation of the model and the sequential expansion of all the parameters
            in the model. When loading a delta checkpoint in strict mode, the hash of the
            backbone model will be compared to the hash in this config.
    """
    delta_type: str = ""

    def __init__(self,
                 modified_modules=None,
                 exclude_modules=None,
                 unfrozen_modules=None,
                 common_structure=False,
                 backbone_class=None,
                 backbone_checkpoint_name=None,
                 backbone_hash=None,
                 ):
        # Bug fix: the previous signature used the mutable default
        # ``unfrozen_modules=["deltas"]`` -- a single list object shared by every
        # instance that did not pass the argument. Use None as sentinel instead.
        if unfrozen_modules is None:
            unfrozen_modules = ["deltas"]
        # Mirror every named __init__ argument onto the instance, so derived
        # configs inherit the bookkeeping for free.
        arg_names = get_arg_names(BaseDeltaConfig.__init__)
        for arg_name in arg_names:
            setattr(self, arg_name, locals()[arg_name])

    @classmethod
    def from_finetuned(cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
        r"""
        Instantiate a :obj:`BaseDeltaConfig` (or a derived class) from a finetuned delta
        module configuration.

        Args:
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`): This can be either:

                * a string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
                  deltahub.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
                * a path to a *directory* containing a configuration file saved using the
                  :meth:`BaseDeltaConfig.save_finetuned` method, e.g., ``./my_model_directory/``.
                * a path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained delta model configuration should be cached if the
                standard cache should not be used.

        .. code-block:: python

            delta_config = LoraConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")
        """
        config_dict, kwargs = cls.get_config_dict(finetuned_delta_path, **kwargs)
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            # logger.warn is deprecated in the stdlib logging API; use warning().
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(config_dict, **kwargs)

    def save_finetuned(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory :obj:`save_directory`, so that it can
        be re-loaded using the :meth:`BaseDeltaConfig.from_finetuned` class method.

        Args:
            save_directory (:obj:`str` or :obj:`os.PathLike`): Directory where the configuration JSON file
                will be saved (will be created if it does not exist).
            push_to_hub (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether or not to push your model to
                the Hugging Face model hub after saving it.

                .. warning::

                    1. Will raise error if you haven't configured a Huggingface Model Hub.
                    2. Using ``push_to_hub=True`` will synchronize the repository you are pushing to with
                       ``save_directory``, which requires ``save_directory`` to be a local clone of the repo you are
                       pushing to if it's an existing folder. Pass along ``temp_dir=True`` to use a temporary
                       directory instead.
            kwargs:
                Additional key word arguments passed along to the
                `PushToHubMixin.push_to_hub <https://huggingface.co/docs/transformers/master/main_classes/model#transformers.file_utils.PushToHubMixin.push_to_hub>`_ method.

        Raises:
            AssertionError: If ``save_directory`` points to an existing *file*.
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            # NOTE(review): _create_or_get_repo/_push_to_hub are not defined in this
            # class -- presumably supplied by a hub mixin elsewhere; TODO confirm.
            repo = self._create_or_get_repo(save_directory, **kwargs)
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined name, we can load using `from_finetuned`.
        output_config_file = os.path.join(save_directory, CONFIG_NAME)
        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")
        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Configuration pushed to the hub in this commit: {url}")

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "BaseDeltaConfig":
        r"""
        Instantiate a :obj:`BaseDeltaConfig` from a python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the :py:meth:`~PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :obj:`BaseDeltaConfig`: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        # Keys acceptable to either the subclass __init__ or the base __init__.
        accept_args = get_arg_names(cls.__init__) + get_arg_names(BaseDeltaConfig.__init__)
        unused_config_keys = []
        for config_key in list(config_dict.keys()):
            if config_key not in accept_args:
                config_dict.pop(config_key)
                unused_config_keys.append(config_key)
        if unused_config_keys:
            # Bug fix: previously this warning fired even when nothing was dropped.
            logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")

        config = cls(**config_dict)
        # Override config attributes with matching kwargs; consumed kwargs are
        # removed so only unused ones are handed back to the caller.
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                # torch_dtype is intentionally left in kwargs for downstream use.
                if key != "torch_dtype":
                    to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        logger.info(f"Model config\n{config}")
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def get_config_dict(
        cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """[NODOC]
        From a ``finetuned_delta_path``, resolve to a dictionary of parameters, to be used for instantiating a
        [``PretrainedConfig``] using ``from_dict``.

        Parameters:
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
                May be a local JSON file, a local directory, or a DeltaCenter identifier.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The config dictionary and the remaining kwargs.
        """
        # ``get`` (not ``pop``): the surviving kwargs are returned to the caller.
        cache_dir = kwargs.get("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        local_files_only = kwargs.get("local_files_only", False)

        # Environment switch that disables all DeltaCenter network access.
        if os.environ.get("DELTACENTER_OFFLINE", '0') == '1':
            logger.info("Delta Center offline mode!")
            local_files_only = True

        finetuned_delta_path = str(finetuned_delta_path)

        if cache_dir is not None:
            cached_finetuned_delta_path = os.path.join(cache_dir, finetuned_delta_path)
        else:
            cached_finetuned_delta_path = finetuned_delta_path

        # Anything already present locally short-circuits the download.
        if os.path.isfile(cached_finetuned_delta_path) or os.path.isdir(cached_finetuned_delta_path):
            local_files_only = True

        if not local_files_only or force_download:
            from .utils.delta_center import download as dcdownload
            # try to download from DeltaCenter
            cached_finetuned_delta_path = dcdownload(
                finetuned_delta_path, force_download=force_download, cache_dir=cache_dir)
            kwargs['force_download'] = False  # Has been downloaded, no more forcing

        # Bug fix: append 'config.json' only when the resolved path is not already a
        # file. Previously it was appended unconditionally, which broke loading from
        # a direct path to a JSON *file* (a case from_finetuned's docstring allows).
        if not os.path.isfile(cached_finetuned_delta_path):
            cached_finetuned_delta_path = os.path.join(cached_finetuned_delta_path, 'config.json')
        config_dict = cls._dict_from_json_file(cached_finetuned_delta_path)
        return config_dict, kwargs

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]) -> Dict[str, Any]:
        # Read a UTF-8 JSON file and return its parsed content as a dict.
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        # e.g. 'LoraConfig { ... json ... }' -- useful in logs and the REPL.
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def __eq__(self, other):
        # Bug fix: guard against non-config operands. Previously any object whose
        # __dict__ happened to match compared equal (and objects without __dict__
        # raised AttributeError). NotImplemented lets Python fall back gracefully.
        if not isinstance(other, BaseDeltaConfig):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def to_json_string(self, use_diff: bool = True) -> str:
        """[NODOC]
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, *optional*, defaults to :obj:`True`):
                If set to :obj:`True`, only the difference between the config instance and the default ``PretrainedConfig()``
                is serialized to JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """[NODOC]
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str` or :obj:`os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, *optional*, defaults to :obj:`True`):
                If set to :obj:`True`, only the difference between the config instance and the default ``PretrainedConfig()``
                is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def to_diff_dict(self) -> Dict[str, Any]:
        """[NODOC]
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()
        # get the default config dict
        default_config_dict = BaseDeltaConfig().to_dict()
        # get class specific config dict
        class_config_dict = self.__class__().to_dict()
        serializable_config_dict = {}
        # Only serialize values that differ from the defaults; version keys
        # (checked_package_versions) are always kept.
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key in checked_package_versions
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value
        self.dict_torch_dtype_to_str(serializable_config_dict)
        return serializable_config_dict

    def update(self, config_dict: Dict[str, Any]):
        """[NODOC]
        Updates attributes of this class with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`dict`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        # Record the package versions in effect at serialization time.
        output["transformers_version"] = transformers_version
        output["opendelta_version"] = opendelta_version
        self.dict_torch_dtype_to_str(output)
        return output

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """[NODOC]
        Checks whether the passed dictionary has a *torch_dtype* key and if it's not None, converts torch.dtype to a
        string of just the type. For example, ``torch.float32`` gets converted into *"float32"* string, which can then
        be stored in the json format.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
if __name__ == "__main__":
    # Manual smoke test: round-trip a saved delta config.
    # Bug fix: BaseDeltaConfig has no from_pretrained/save_pretrained methods --
    # the public API is from_finetuned/save_finetuned, so the old calls raised
    # AttributeError before touching the checkpoint.
    myconfig = BaseDeltaConfig.from_finetuned("../ckpts/lora/")
    myconfig.save_finetuned("../ckpts/lora.1/")
    print(myconfig)