import os
import re
from typing import Union, Dict, Any, Tuple, Optional
from opendelta import __version__ as opendelta_version
from opendelta.utils import logging
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
import transformers
import json
import copy

CONFIG_NAME = "config.json"
transformers_version = transformers.__version__

checked_package_versions = ["transformers_version", "opendelta_version"]

logger = logging.get_logger(__name__)
FULL_CONFIGURATION_FILE = "config.json"
_re_configuration_file = re.compile(r"config\.(.*)\.json")


class BaseDeltaConfig:
    r"""Base class for all configuration classes. Handles a few
    parameters common to all delta models' configurations, as well as methods for loading/downloading/saving configurations.

    Class attributes (overridden by derived classes):

    - **delta_type** (:obj:`str`) -- the name of the delta module type, used to create the correct :py:class:`~opendelta.AutoConfig`.

    Args:
        modified_modules (:obj:`List[str]`, *optional*, defaults to :obj:`None`)
            The list of keys that determine which modules you want to modify. OpenDelta will take every module that
            **ends with** one of the provided keys as a modification target. When no value is given, i.e.
            ``modified_modules=None``, the delta module will use its corresponding default modified modules.
            Taking DistilBertModel with a classifier on top as an example:

            .. note::
                **Examples**: When adding delta to DistilBertModel,

                1. setting this to ``["0.attention.out_lin"]`` will add delta modules to the attention output of distilbert's
                layer 0, i.e., ``distilbert.transformer.layer.0.attention.out_lin``.

                2. setting this to ``["attention.out_lin"]`` will add delta modules to every layer's ``attention.out_lin``.

        exclude_modules (:obj:`str`, *optional*, defaults to :obj:`None`): The modules that start with these strings will
            be excluded from modification. Note that currently only plain text (no regular expressions) is supported.
        unfrozen_modules (:obj:`List[str]`, *optional*, defaults to :obj:`["deltas"]`)
            The modules that are unfrozen
            during training, including the ones that are newly introduced as delta modules and the ones that are
            originally a part of the model but set to trainable (:obj:`requires_grad=True`) to train together with the
            delta modules. OpenDelta will take every module that **ends with** one of the provided keys, together with all
            of its sub-modules and parameters, as trainable.

            .. note::
                **Examples**: When adding delta to DistilBertModel,

                1. setting this argument to ``["bias"]`` will make all bias terms tunable.

                2. setting this argument to ``["attention"]`` will make all parameters in all attention modules tunable.

                3. setting this argument to ``["deltas"]`` will make all the parameters in the newly introduced delta
                modules tunable.

                4. setting this argument to ``["classifier"]`` will make all parameters in the classifier tunable.

                5. setting this argument to ``["3.ffn.lin2", "deltas", "classifier"]`` will make all parameters in
                the third layer's feed-forward network's second linear layer, the delta modules, and the classifier module
                tunable.

        common_structure (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use the common structure mapping of
            the transformer model when designating :obj:`modified_modules` and :obj:`unfrozen_modules`.
        backbone_class (:obj:`str`, *optional*, defaults to :obj:`None`): The name of the backbone model's class, e.g.
            ``RobertaForMaskedLM``. Saving this information lets users know explicitly on which backbone the
            delta model was trained.
        backbone_checkpoint_name (:obj:`str`, *optional*, defaults to :obj:`None`): The specific checkpoint of the model.
            In the ideal case, it should be a url from which the checkpoint can be downloaded. However, we do not force the user to
            specify a downloadable url here.
        backbone_hash (:obj:`str`, *optional*, defaults to :obj:`None`): The md5 hash of the backbone model. It is
            calculated from the string representation of the model and the sequential expansion of all the
            parameters in the model. When loading a delta checkpoint in strict mode, the hash of the backbone model
            will be compared to the hash in this config.
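
    For example, a minimal sketch of configuring a concrete delta type (``LoraConfig`` stands in
    for any derived class; the argument values are illustrative):

    .. code-block:: python

        from opendelta import LoraConfig

        # Modify every layer's attention output; keep the newly introduced
        # delta parameters plus the classifier head trainable.
        delta_config = LoraConfig(
            modified_modules=["attention.out_lin"],
            unfrozen_modules=["deltas", "classifier"],
        )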
    """

    delta_type: str = ""

    def __init__(self,
                 modified_modules = None,
                 exclude_modules = None,
                 unfrozen_modules = ["deltas"],
                 common_structure=False,
                 backbone_class = None,
                 backbone_checkpoint_name = None,
                 backbone_hash = None,
                 ):
        # Record every constructor argument as an attribute of this config.
        arg_names = get_arg_names(BaseDeltaConfig.__init__)
        for arg_name in arg_names:
            setattr(self, arg_name, locals()[arg_name])

    @classmethod
    def from_finetuned(cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs) -> "BaseDeltaConfig":
        r"""
        Instantiate a :obj:`BaseDeltaConfig` (or a derived class) from a finetuned delta module configuration.

        Args:
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`): This can be either:

                * a string, the *model id* of a finetuned delta model configuration hosted inside a model repo on
                  deltahub.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
                  namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.

                * a path to a *directory* containing a configuration file saved using the :meth:`BaseDeltaConfig.save_finetuned` method, e.g., ``./my_model_directory/``.

                * a path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``.

            cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained delta model configuration should be cached if the
                standard cache should not be used.

        .. code-block:: python

            delta_config = LoraConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")
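
            # The same method also loads from a local directory previously written
            # by ``save_finetuned`` (the path below is illustrative):
            delta_config = LoraConfig.from_finetuned("./my_model_directory/")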
        """
        config_dict, kwargs = cls.get_config_dict(finetuned_delta_path, **kwargs)
        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
        return cls.from_dict(config_dict, **kwargs)

    def save_finetuned(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory :obj:`save_directory`, so that it can be re-loaded using the
        :meth:`BaseDeltaConfig.from_finetuned` class method.

        Args:
            save_directory (:obj:`str` or :obj:`os.PathLike`): Directory where the configuration JSON file
                will be saved (will be created if it does not exist).
            push_to_hub (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether or not to push your model to
                the Hugging Face model hub after saving it.

                .. warning::
                    1. This will raise an error if you have not configured a Hugging Face Model Hub account.

                    2. Using ``push_to_hub=True`` will synchronize the repository you are pushing to with ``save_directory``,
                    which requires ``save_directory`` to be a local clone of the repo you are pushing to if it's an existing
                    folder. Pass along ``temp_dir=True`` to use a temporary directory instead.

            kwargs:
                Additional keyword arguments passed along to the
                `PushToHubMixin.push_to_hub <https://huggingface.co/docs/transformers/master/main_classes/model#transformers.file_utils.PushToHubMixin.push_to_hub>`_ method.
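
        For example, a minimal sketch (the directory name is illustrative):

        .. code-block:: python

            delta_config.save_finetuned("./my_model_directory/")
            # ...which can be re-loaded later with, e.g.:
            # delta_config = LoraConfig.from_finetuned("./my_model_directory/")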
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            repo = self._create_or_get_repo(save_directory, **kwargs)

        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info(f"Configuration saved in {output_config_file}")

        if push_to_hub:
            url = self._push_to_hub(repo, commit_message=commit_message)
            logger.info(f"Configuration pushed to the hub in this commit: {url}")

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "BaseDeltaConfig":
        r"""
        Instantiate a :obj:`BaseDeltaConfig` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the :py:meth:`~PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :obj:`BaseDeltaConfig`: The configuration object instantiated from those parameters.
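
        For example, a minimal sketch (assuming ``LoraConfig`` is a concrete derived class and
        ``lora_r`` is one of its ``__init__`` arguments):

        .. code-block:: python

            config = LoraConfig.from_dict({"lora_r": 8}, modified_modules=["attention"])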
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
        accept_args = get_arg_names(cls.__init__) + get_arg_names(BaseDeltaConfig.__init__)
        # Drop keys that neither this class's nor the base class's __init__ accepts.
        unused_config_keys = []
        for config_key in list(config_dict.keys()):
            if config_key not in accept_args:
                config_dict.pop(config_key)
                unused_config_keys.append(config_key)
        if unused_config_keys:
            logger.warning(f"The following keys are not used by {cls}.__init__ function: {unused_config_keys}")

        config = cls(**config_dict)

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                if key != "torch_dtype":
                    to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        logger.info(f"Model config\n{config}")

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    @classmethod
    def get_config_dict(
        cls, finetuned_delta_path: Union[str, os.PathLike], **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """[NODOC]
        From a ``finetuned_delta_path``, resolve to a dictionary of parameters, to be used for instantiating a
        ``PretrainedConfig`` using ``from_dict``.

        Parameters:
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
        """
        cache_dir = kwargs.get("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        # resume_download = kwargs.pop("resume_download", False)
        # proxies = kwargs.pop("proxies", None)
        # use_auth_token = kwargs.pop("use_auth_token", None)
        local_files_only = kwargs.get("local_files_only", False)
        # revision = kwargs.pop("revision", None)
        # from_pipeline = kwargs.pop("_from_pipeline", None)
        # from_auto_class = kwargs.pop("_from_auto", False)

        # user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
        # if from_pipeline is not None:
        #     user_agent["using_pipeline"] = from_pipeline

        if os.environ.get("DELTACENTER_OFFLINE", '0') == '1':
            logger.info("Delta Center offline mode!")
            local_files_only = True

        finetuned_delta_path = str(finetuned_delta_path)

        if cache_dir is not None:
            cached_finetuned_delta_path = os.path.join(cache_dir, finetuned_delta_path)
        else:
            cached_finetuned_delta_path = finetuned_delta_path

        # An existing local file or directory means no download is needed.
        if os.path.isfile(cached_finetuned_delta_path):
            local_files_only = True
        elif os.path.isdir(cached_finetuned_delta_path):
            local_files_only = True

        # if local_files_only:
        #     config_dict = cls._dict_from_json_file(cached_finetuned_delta_path)
        if not local_files_only or force_download:
            from .utils.delta_center import download as dcdownload
            # try to download from DeltaCenter
            cached_finetuned_delta_path = dcdownload(finetuned_delta_path, force_download=force_download, cache_dir=cache_dir)
            kwargs['force_download'] = False  # already downloaded; no need to force again

        if os.path.isdir(cached_finetuned_delta_path):
            cached_finetuned_delta_path = os.path.join(cached_finetuned_delta_path, 'config.json')
        config_dict = cls._dict_from_json_file(cached_finetuned_delta_path)
        return config_dict, kwargs

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def to_json_string(self, use_diff: bool = True) -> str:
        """[NODOC]
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, *optional*, defaults to :obj:`True`):
                If set to :obj:`True`, only the difference between the config instance and the default ``BaseDeltaConfig()``
                is serialized to a JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON format.
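
        For example, a minimal sketch:

        .. code-block:: python

            json_str = delta_config.to_json_string(use_diff=False)  # serialize the full config, not just the diff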
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
        """[NODOC]
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str` or :obj:`os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, *optional*, defaults to :obj:`True`):
                If set to :obj:`True`, only the difference between the config instance and the default ``BaseDeltaConfig()``
                is serialized to the JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def to_diff_dict(self) -> Dict[str, Any]:
        """[NODOC]
        Removes all attributes from the config which correspond to the default config attributes, for better readability,
        and serializes to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = BaseDeltaConfig().to_dict()

        # get class specific config dict
        class_config_dict = self.__class__().to_dict()  # if not self.is_composition else {}

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key in checked_package_versions
                or value != default_config_dict[key]
                or (key in class_config_dict and value != class_config_dict[key])
            ):
                serializable_config_dict[key] = value

        self.dict_torch_dtype_to_str(serializable_config_dict)

        return serializable_config_dict

    def update(self, config_dict: Dict[str, Any]):
        """[NODOC]
        Updates attributes of this class with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
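
        For example, a minimal sketch (the values are illustrative):

        .. code-block:: python

            delta_config.update({"common_structure": True, "backbone_class": "RobertaForMaskedLM"})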
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`dict`: Dictionary of all the attributes that make up this configuration instance.
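
        For example, a minimal sketch:

        .. code-block:: python

            config_dict = delta_config.to_dict()
            # The package versions are always recorded for compatibility checks:
            assert "opendelta_version" in config_dict and "transformers_version" in config_dict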
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type

        # Record the package versions in use when serializing the model.
        output["transformers_version"] = transformers_version
        output["opendelta_version"] = opendelta_version

        self.dict_torch_dtype_to_str(output)

        return output

    def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """[NODOC]
        Checks whether the passed dictionary has a *torch_dtype* key and, if it's not None, converts the torch.dtype to a
        string of just the type. For example, ``torch.float32`` gets converted into the string *"float32"*, which can then be
        stored in JSON format.
        """
        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]


if __name__ == "__main__":
    myconfig = BaseDeltaConfig.from_finetuned("../ckpts/lora/")
    myconfig.save_finetuned("../ckpts/lora.1/")
    print(myconfig)