OpenDeltaMirror/opendelta/utils/backend.py

111 lines
3.8 KiB
Python
Raw Permalink Normal View History

2022-10-23 16:42:21 +08:00
import importlib
class BackendMapping:
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
Args:
- config_mapping: The map model type to config class
- model_mapping: The map model type to model (or tokenizer) class
"""
def __init__(self, backend):
self.backend = backend
assert backend in ['hf', 'bmt'], "Backend should be one of 'hf', 'bmt'. "
if backend == 'hf':
self.backend_mapping = {
"linear": "torch.nn.Linear",
"layer_norm": "torch.nn.LayerNorm",
"module": "torch.nn.Module",
"parameter": "torch.nn.Parameter"
}
elif backend == 'bmt':
self.backend_mapping = {
"linear": "model_center.layer.Linear",
"layer_norm": "model_center.layer.LayerNorm",
"module": "bmtrain.layer.DistributedModule",
"parameter": "bmtrain.nn.DistributedParameter"
}
self.registered = {}
def load(self, model_type):
if model_type not in self.registered:
splited = self.backend_mapping[model_type].split(".")
module_name, class_name = ".".join(splited[:-1]), splited[-1]
module = importlib.import_module(module_name)
the_class = getattr(module, class_name)
self.registered[model_type] = the_class
return self.registered[model_type]
def check_type(self, module, expect_type):
the_class = self.load(expect_type)
if isinstance(module, the_class):
return True
else:
return False
# def keys(self):
# mapping_keys = [
# self._load_attr_from_module(key, name)
# for key, name in self._config_mapping.items()
# if key in self._model_mapping.keys()
# ]
# return mapping_keys + list(self._extra_content.keys())
# def get(self, key, default):
# try:
# return self.__getitem__(key)
# except KeyError:
# return default
# def __bool__(self):
# return bool(self.keys())
# def values(self):
# mapping_values = [
# self._load_attr_from_module(key, name)
# for key, name in self._model_mapping.items()
# if key in self._config_mapping.keys()
# ]
# return mapping_values + list(self._extra_content.values())
# def items(self):
# mapping_items = [
# (
# self._load_attr_from_module(key, self._config_mapping[key]),
# self._load_attr_from_module(key, self._model_mapping[key]),
# )
# for key in self._model_mapping.keys()
# if key in self._config_mapping.keys()
# ]
# return mapping_items + list(self._extra_content.items())
# def __iter__(self):
# return iter(self.keys())
# def __contains__(self, item):
# if item in self._extra_content:
# return True
# if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
# return False
# model_type = self._reverse_config_mapping[item.__name__]
# return model_type in self._model_mapping
# def register(self, key, value):
# """
# Register a new model in this mapping.
# """
# if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
# model_type = self._reverse_config_mapping[key.__name__]
# if model_type in self._model_mapping.keys():
# raise ValueError(f"'{key}' is already used by a Transformers model.")
# self._extra_content[key] = value