111 lines
3.8 KiB
Python
111 lines
3.8 KiB
Python
|
|
|
|
import importlib
|
|
|
|
|
|
class BackendMapping:
    """Lazily resolve backend-specific classes from short type names.

    Each supported backend maps a short key (``"linear"``, ``"layer_norm"``,
    ``"module"``, ``"parameter"``) to the dotted import path of a class.
    Nothing is imported at construction time; a class is imported the first
    time it is requested through :meth:`load` and then cached in
    ``self.registered``.

    Args:
        backend: Name of the backend, either ``'hf'`` or ``'bmt'``.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """

    def __init__(self, backend):
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an invalid backend through and
        # leave ``self.backend_mapping`` undefined.
        if backend not in ('hf', 'bmt'):
            raise ValueError("Backend should be one of 'hf', 'bmt'. ")
        self.backend = backend
        if backend == 'hf':
            self.backend_mapping = {
                "linear": "torch.nn.Linear",
                "layer_norm": "torch.nn.LayerNorm",
                "module": "torch.nn.Module",
                "parameter": "torch.nn.Parameter",
            }
        else:  # backend == 'bmt'
            self.backend_mapping = {
                "linear": "model_center.layer.Linear",
                "layer_norm": "model_center.layer.LayerNorm",
                "module": "bmtrain.layer.DistributedModule",
                "parameter": "bmtrain.nn.DistributedParameter",
            }
        # Cache of already-imported classes, keyed by the short type name.
        self.registered = {}

    def load(self, model_type):
        """Return the class mapped to ``model_type``, importing it on first use.

        Args:
            model_type: A key of ``self.backend_mapping``.

        Returns:
            The attribute named by the last component of the dotted path,
            looked up on the module named by the preceding components.

        Raises:
            KeyError: If ``model_type`` is not a key of ``self.backend_mapping``.
        """
        if model_type in self.registered:
            return self.registered[model_type]
        try:
            dotted_path = self.backend_mapping[model_type]
        except KeyError:
            raise KeyError(
                "Unknown type {!r}; expected one of {}.".format(
                    model_type, sorted(self.backend_mapping)
                )
            ) from None
        # Everything before the last dot is the module, the rest is the class.
        module_name, _, class_name = dotted_path.rpartition(".")
        module = importlib.import_module(module_name)
        the_class = getattr(module, class_name)
        # Only cache after a successful import so failures are retried.
        self.registered[model_type] = the_class
        return the_class

    def check_type(self, module, expect_type):
        """Return True if ``module`` is an instance of the class for ``expect_type``.

        Args:
            module: The object to test.
            expect_type: A key of ``self.backend_mapping``; resolved via
                :meth:`load`.
        """
        return isinstance(module, self.load(expect_type))
|
|
|
|
|
|
# def keys(self):
|
|
# mapping_keys = [
|
|
# self._load_attr_from_module(key, name)
|
|
# for key, name in self._config_mapping.items()
|
|
# if key in self._model_mapping.keys()
|
|
# ]
|
|
# return mapping_keys + list(self._extra_content.keys())
|
|
|
|
# def get(self, key, default):
|
|
# try:
|
|
# return self.__getitem__(key)
|
|
# except KeyError:
|
|
# return default
|
|
|
|
# def __bool__(self):
|
|
# return bool(self.keys())
|
|
|
|
# def values(self):
|
|
# mapping_values = [
|
|
# self._load_attr_from_module(key, name)
|
|
# for key, name in self._model_mapping.items()
|
|
# if key in self._config_mapping.keys()
|
|
# ]
|
|
# return mapping_values + list(self._extra_content.values())
|
|
|
|
# def items(self):
|
|
# mapping_items = [
|
|
# (
|
|
# self._load_attr_from_module(key, self._config_mapping[key]),
|
|
# self._load_attr_from_module(key, self._model_mapping[key]),
|
|
# )
|
|
# for key in self._model_mapping.keys()
|
|
# if key in self._config_mapping.keys()
|
|
# ]
|
|
# return mapping_items + list(self._extra_content.items())
|
|
|
|
# def __iter__(self):
|
|
# return iter(self.keys())
|
|
|
|
# def __contains__(self, item):
|
|
# if item in self._extra_content:
|
|
# return True
|
|
# if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
|
|
# return False
|
|
# model_type = self._reverse_config_mapping[item.__name__]
|
|
# return model_type in self._model_mapping
|
|
|
|
# def register(self, key, value):
|
|
# """
|
|
# Register a new model in this mapping.
|
|
# """
|
|
# if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
|
|
# model_type = self._reverse_config_mapping[key.__name__]
|
|
# if model_type in self._model_mapping.keys():
|
|
# raise ValueError(f"'{key}' is already used by a Transformers model.")
|
|
|
|
# self._extra_content[key] = value
|
|
|
|
|