This commit is contained in:
Achazwl 2022-05-09 15:17:41 +08:00
parent 351aa3a40d
commit f8db5be89b
2 changed files with 157 additions and 41 deletions


@@ -1,5 +1,7 @@
from functools import partial
from typing import Optional, Union
@@ -20,19 +22,30 @@ from opendelta import global_setting
logger = logging.get_logger(__name__)
from itertools import accumulate
from opendelta.delta_models.adapter import AdapterLayer
from opendelta.delta_models.lora import LowRankLinear
class SplitLayer(nn.Module):
class _SplitLayer(nn.Module):
r"""A layer of splitting module.
"""
def __init__(self, batch_size:list):
def __init__(self):
super().__init__()
self.batch_size = list(accumulate(batch_size))
self.modulelist = nn.ModuleList()
self.pseudo_inited = False
self.module_dict = nn.ModuleDict()
def append(self, module):
self.modulelist.append(module)
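# Register a delta module under `module_name`; returns False if the name is already taken.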
def attach(self, module_name: str, module: nn.Module):
if module_name in self.module_dict:
return False
self.module_dict[module_name] = module
return True
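# Remove the delta module registered under `module_name`; returns False if no such module exists.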
def detach(self, module_name: str):
if module_name not in self.module_dict:
return False
self.module_dict.pop(module_name)
return True
class SplitSequentialLayer(_SplitLayer):
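# Inserted sequentially after a backbone module: each attached delta module post-processes
# the hidden states, and the per-module outputs are merged by summation.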
def __init__(self):
super().__init__()
def post_forward(self, output):
if isinstance(output, tuple):
@@ -41,23 +54,14 @@ class SplitLayer(nn.Module):
hiddens = output
else:
raise TypeError
if hiddens.shape[0] != self.batch_size[-1]:
if self.pseudo_inited:
raise RuntimeError('The batch size of the input is not consistent with split config.')
self.pseudo_inited = True
outputs = None
for i in range(len(self.batch_size)):
outputs = self.modulelist[i].post_forward(
hiddens
)
merge_output = outputs
else:
split_outputs = [None]*len(self.batch_size)
for i in range(len(self.batch_size)):
split_outputs[i] = self.modulelist[i].post_forward(
hiddens[(0 if i==0 else self.batch_size[i-1]):self.batch_size[i]]
)
merge_output = torch.cat(split_outputs)
split_outputs = []
for module_name, module in self.module_dict.items():
split_outputs.append( module.post_forward(
hiddens
) )
logger.debug("merging outputs from {} split modules".format(len(split_outputs)))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
if isinstance(output, tuple):
output = (merge_output,) + output[1:]
@@ -67,6 +71,22 @@ class SplitLayer(nn.Module):
raise TypeError
return output
class SplitParallelLayer(_SplitLayer):
r"""A layer of splitting module.
"""
def __init__(self):
super().__init__()
def forward(self, hiddens):
split_outputs = []
for module_name, module in self.module_dict.items():
split_outputs.append( module(
hiddens
) )
logger.debug("merging outputs from {} split modules".format(len(split_outputs)))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
return merge_output
class SplitConfig(BaseDeltaConfig):
r"""
This is the configuration class to store the configuration of a :py:class:`~SplitModel`
@@ -74,7 +94,6 @@ class SplitConfig(BaseDeltaConfig):
"""
def __init__(
self,
batch_size: list = [8, 1, 7],
**kwargs
):
super().__init__(**kwargs)
@@ -98,7 +117,7 @@ class SplitModel(DeltaBase):
class attributes:
- default_modified_modules = ["attn", "ff"] According to the Adapter paper, we add adapter to the attention layer
and feed forward layer.
- delta_type = "adapter"
- delta_type = "batch_split"
Args:
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
@@ -115,12 +134,11 @@ class SplitModel(DeltaBase):
"""
config_class = SplitConfig
delta_type = "adapter"
delta_type = "batch_split"
default_modified_modules = ["attn", "ff"]
def __init__(self,
backbone_model: nn.Module,
batch_size: list = [8, 1, 7],
modified_modules: Optional[List[str]] = None,
modified_modules: Optional[List[str]] = [],
exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
@@ -145,15 +163,102 @@ class SplitModel(DeltaBase):
self.modified_modules,
)
def update_module(self, module: nn.Module, key: str):
_, _, ref = self.find_module(module, key)
splitlayer = SplitLayer(self.batch_size)
for b in self.batch_size:
splitlayer.append(self.new_module_like(ref))
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split")
self.modified_points = {}
def new_module_like(self, module):
module_device = get_device(module)
def add_all_delta_to_backbone(self,
backbone: nn.Module,
modified_modules: List[str],
) -> nn.Module:
r"""The main function to add delta models to the backbone model based on the :obj:`modified_modules`.
Args:
backbone_model (:obj:`nn.Module`, *required*): The backbone model that the delta modules are built upon. The
modification to the backbone model is made in place.
modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules subject to update.
Leaving this argument :obj:`None` makes the delta model fall back to the default setting, which adds the delta
modules at the positions experimented with in the paper. In this setting, the common structure mapping is loaded
to address the corresponding modules.
Returns:
:obj:`nn.Module`: The modified backbone model.
"""
self.plm_total_params = sum(p.numel() for p in backbone.parameters())
# create a new key list to avoid recursion.
self.backbone_key_list = [key for key, _ in backbone.named_modules()]
return backbone
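# Attach a named delta module (an adapter or a LoRA) to every backbone submodule matched by
# `modified_point`. A SplitSequentialLayer (for adapters) or SplitParallelLayer (for LoRA) is
# created lazily at each matched point; all modules attached to one point must share a module_type.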
def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None):
if isinstance(modified_point, str):
modified_point = [modified_point]
if module_type not in ["adapter", "lora"]:
raise ValueError("module_type must be either adapter or lora")
for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO: may have bugs when the common structure has a virtual node and it is referred to
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
if key not in self.modified_points:
if module_type == "adapter":
splitlayer = SplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
if module_type == "lora":
splitlayer = SplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
self.modified_points[key] = splitlayer
splitlayer = self.modified_points[key]
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)):
raise ValueError("one modified_point can have at most one module_type")
if module is None:
if module_type == "adapter":
module = self.new_adapter_like(ref)
if module_type == "lora":
module = self.new_lora_like(ref)
if not splitlayer.attach(module_name, module):
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
def update(self):
self._pseudo_data_to_instantiate(self.backbone_model)
self.mark_as_delta()
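# Remove the delta module named `module_name` from every matched modified point.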
def detach(self, modified_point: str, module_name: str):
if isinstance(modified_point, str):
modified_point = [modified_point]
for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO: may have bugs when the common structure has a virtual node and it is referred to
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
if key not in self.modified_points:
raise ValueError("no module has been added to {}".format(key))
splitlayer = self.modified_points[key]
if not splitlayer.detach(module_name):
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
def new_adapter_like(self, module):
adapterlayer = AdapterLayer()
self.delta_modules.append(adapterlayer)
return adapterlayer
def new_lora_like(self, child_module):
if isinstance(child_module, nn.Linear):
in_features, out_features = child_module.in_features, child_module.out_features
new_module = LowRankLinear(in_features = in_features,
out_features = out_features,
weight = child_module.weight,
r=self.lora_r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
self.delta_modules.append(new_module)
else:
raise NotImplementedError
return new_module
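For intuition, here is a minimal, self-contained sketch (not part of the commit) of the merge rule both split layers use: the outputs of all modules attached to the same point are stacked and summed, so zero-initialized modules start as a no-op.

import torch
import torch.nn as nn

# Two toy stand-ins for delta modules attached to the same modified point.
dim = 32
mod_a, mod_b = nn.Linear(dim, dim), nn.Linear(dim, dim)
for m in (mod_a, mod_b):
    nn.init.zeros_(m.weight)
    nn.init.zeros_(m.bias)

hiddens = torch.randn(4, 16, dim)                        # (batch, seq, hidden)
outputs = [m(hiddens) for m in (mod_a, mod_b)]           # one output per attached module
merged = torch.sum(torch.stack(outputs, dim=0), dim=0)   # same merge as the split layers
print(merged.shape)                                      # torch.Size([4, 16, 32])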

test.py

@@ -3,13 +3,24 @@ model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model, batch_size=[1]*16, modified_modules=['output.dense'])
delta_model = SplitModel(model)
print("attach here")
delta_model.attach("output.dense", "adapter_A", "adapter")
delta_model.attach("output.dense", "adapter_B", "adapter")
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
delta_model.update()
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
delta_model.log() # This will visualize the backbone after modification and other information.
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
import time
model = model.cuda()
st_time = time.time()
for t in range(10):
y = model(x)
print("run here")
y = model(x)
print("detach here")
delta_model.detach("3.output.dense", "adapter_A")
delta_model.log()
print("run here")
y = model(x)
print(time.time() - st_time)
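A caveat on the timing above (not part of the commit): CUDA kernels launch asynchronously, so a rough sketch like the following, assuming a CUDA device is available, gives more reliable wall-clock numbers.

import time
import torch

def timed_forward(model, x, repeats=10):
    # Synchronize before and after so the measured span covers the actual GPU work.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeats):
        with torch.no_grad():
            _ = model(x)
    torch.cuda.synchronize()
    return (time.time() - start) / repeats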