Commit f8db5be89b ("dynamic")
Parent: 351aa3a40d
@@ -1,5 +1,7 @@
from functools import partial
from hashlib import sha1
from random import random
from sqlite3 import adapters
from typing import Optional, Union

from cv2 import accumulate
@@ -20,19 +22,30 @@ from opendelta import global_setting
logger = logging.get_logger(__name__)
from itertools import accumulate
from opendelta.delta_models.adapter import AdapterLayer
from opendelta.delta_models.lora import LowRankLinear


class SplitLayer(nn.Module):
class _SplitLayer(nn.Module):
    r"""A layer of splitting module.
    """
    def __init__(self, batch_size: list):
    def __init__(self):
        super().__init__()
        self.batch_size = list(accumulate(batch_size))
        self.modulelist = nn.ModuleList()
        self.pseudo_inited = False
        self.module_dict = nn.ModuleDict()

    def append(self, module):
        self.modulelist.append(module)
    def attach(self, module_name: str, module: nn.Module):
        if module_name in self.module_dict:
            return False
        self.module_dict[module_name] = module
        return True

    def detach(self, module_name: str):
        if module_name not in self.module_dict:
            return False
        self.module_dict.pop(module_name)
        return True

class SplitSequentialLayer(_SplitLayer):
    def __init__(self):
        super().__init__()

    def post_forward(self, output):
        if isinstance(output, tuple):
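In the new _SplitLayer base class (which replaces the old batch-size-based SplitLayer), attached delta modules live in an nn.ModuleDict keyed by a user-chosen name, and attach/detach signal success through their boolean return value instead of raising. A minimal sketch of that contract, assuming the _SplitLayer definition above is available; the nn.Linear arguments are placeholders only:

import torch.nn as nn

layer = _SplitLayer()
ok = layer.attach("adapter_A", nn.Linear(4, 4))   # True: new name registered
dup = layer.attach("adapter_A", nn.Linear(4, 4))  # False: name already taken
gone = layer.detach("adapter_B")                  # False: nothing under that name
print(ok, dup, gone)                              # True False False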
@@ -41,23 +54,14 @@ class SplitLayer(nn.Module):
            hiddens = output
        else:
            raise TypeError
        if hiddens.shape[0] != self.batch_size[-1]:
            if self.pseudo_inited:
                raise RuntimeError('The batch size of the input is not consistent with split config.')
            self.pseudo_inited = True
            outputs = None
            for i in range(len(self.batch_size)):
                outputs = self.modulelist[i].post_forward(
                    hiddens
                )
            merge_output = outputs
        else:
            split_outputs = [None] * len(self.batch_size)
            for i in range(len(self.batch_size)):
                split_outputs[i] = self.modulelist[i].post_forward(
                    hiddens[(0 if i == 0 else self.batch_size[i-1]):self.batch_size[i]]
                )
            merge_output = torch.cat(split_outputs)

        split_outputs = []
        for module_name, module in self.module_dict.items():
            split_outputs.append(module.post_forward(
                hiddens
            ))
        print(len(split_outputs))
        merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)

        if isinstance(output, tuple):
            output = (merge_output,) + output[1:]
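The new SplitSequentialLayer.post_forward drops the old batch_size-based slicing and instead runs every attached module on the full hidden states, then merges the results by element-wise summation. A standalone sketch of that merge rule; SimpleAdapter and the tensor sizes are made-up stand-ins, not OpenDelta's AdapterLayer:

import torch
import torch.nn as nn

class SimpleAdapter(nn.Module):
    # Toy stand-in for an attached delta module.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def post_forward(self, hiddens):
        return self.proj(hiddens)

module_dict = nn.ModuleDict({"adapter_A": SimpleAdapter(16),
                             "adapter_B": SimpleAdapter(16)})
hiddens = torch.randn(4, 8, 16)                          # (batch, seq_len, hidden)
outputs = [m.post_forward(hiddens) for m in module_dict.values()]
merged = torch.sum(torch.stack(outputs, dim=0), dim=0)   # sum over all attached modules
assert merged.shape == hiddens.shape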
@@ -67,6 +71,22 @@ class SplitLayer(nn.Module):
            raise TypeError
        return output

class SplitParallelLayer(_SplitLayer):
    r"""A layer of splitting module.
    """
    def __init__(self):
        super().__init__()

    def forward(self, hiddens):
        split_outputs = []
        for module_name, module in self.module_dict.items():
            split_outputs.append(module(
                hiddens
            ))
        print(len(split_outputs))
        merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
        return merge_output

class SplitConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~SplitModel`
@@ -74,7 +94,6 @@ class SplitConfig(BaseDeltaConfig):
    """
    def __init__(
        self,
        batch_size: list = [8, 1, 7],
        **kwargs
    ):
        super().__init__(**kwargs)
@@ -98,7 +117,7 @@ class SplitModel(DeltaBase):
    class attributes:
        - default_modified_modules = ["attn", "ff"] According to the Adapter paper, we add adapters to the attention
          layer and the feed-forward layer.
        - delta_type = "adapter"
        - delta_type = "batch_split"

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
@@ -115,12 +134,11 @@ class SplitModel(DeltaBase):

    """
    config_class = SplitConfig
    delta_type = "adapter"
    delta_type = "batch_split"
    default_modified_modules = ["attn", "ff"]
    def __init__(self,
                 backbone_model: nn.Module,
                 batch_size: list = [8, 1, 7],
                 modified_modules: Optional[List[str]] = None,
                 modified_modules: Optional[List[str]] = [],
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
@@ -145,15 +163,102 @@
                                          self.modified_modules,
                                          )

    def update_module(self, module: nn.Module, key: str):
        _, _, ref = self.find_module(module, key)
        splitlayer = SplitLayer(self.batch_size)
        for b in self.batch_size:
            splitlayer.append(self.new_module_like(ref))
        self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split")
        self.modified_points = {}

    def new_module_like(self, module):
        module_device = get_device(module)
    def add_all_delta_to_backbone(self,
                                  backbone: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        r"""The main function to add delta models to the backbone model based on the :obj:`modified_modules`.

        Args:
            backbone_model (:obj:`nn.Module`, *required*): The backbone model that the delta models are built upon.
                The modification to the backbone model is in place.
            modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules subject to update.
                Leaving this argument :obj:`None` makes the delta model fall back to the default setting, which adds
                the delta models at the positions used in the paper. In this setting, the common structure mapping is
                loaded to address the corresponding modules.

        Returns:
            :obj:`nn.Module`: The modified backbone model.

        """
        self.plm_total_params = sum(p.numel() for p in backbone.parameters())
        # create a new key list to avoid recursion.
        self.backbone_key_list = [key for key, _ in backbone.named_modules()]
        return backbone

    def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None):
        if isinstance(modified_point, str):
            modified_point = [modified_point]

        if module_type not in ["adapter", "lora"]:
            raise ValueError("module_type must be either adapter or lora")

        for key in self.backbone_key_list:
            if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's referred
                logger.debug("find key: {}".format(key))
                _, _, ref = self.find_module(self.backbone_model, key)

                if key not in self.modified_points:
                    if module_type == "adapter":
                        splitlayer = SplitSequentialLayer()
                        self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
                    if module_type == "lora":
                        splitlayer = SplitParallelLayer()
                        self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
                    self.modified_points[key] = splitlayer

                splitlayer = self.modified_points[key]
                if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
                   (module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)):
                    raise ValueError("one modified_point can have at most one module_type")

                if module is None:
                    if module_type == "adapter":
                        module = self.new_adapter_like(ref)
                    if module_type == "lora":
                        module = self.new_lora_like(ref)

                if not splitlayer.attach(module_name, module):
                    raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))

    def update(self):
        self._pseudo_data_to_instantiate(self.backbone_model)
        self.mark_as_delta()

    def detach(self, modified_point: str, module_name: str):
        if isinstance(modified_point, str):
            modified_point = [modified_point]

        for key in self.backbone_key_list:
            if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's referred
                logger.debug("find key: {}".format(key))
                _, _, ref = self.find_module(self.backbone_model, key)

                if key not in self.modified_points:
                    raise ValueError("no module has been added to {}".format(key))

                splitlayer = self.modified_points[key]

                if not splitlayer.detach(module_name):
                    raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))

    def new_adapter_like(self, module):
        adapterlayer = AdapterLayer()
        self.delta_modules.append(adapterlayer)
        return adapterlayer

    def new_lora_like(self, child_module):
        if isinstance(child_module, nn.Linear):
            in_features, out_features = child_module.in_features, child_module.out_features
            new_module = LowRankLinear(in_features=in_features,
                                       out_features=out_features,
                                       weight=child_module.weight,
                                       r=self.lora_r,
                                       lora_alpha=self.lora_alpha,
                                       lora_dropout=self.lora_dropout)
            self.delta_modules.append(new_module)
        else:
            raise NotImplementedError
        return new_module
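For module_type "lora", attach wraps the matched nn.Linear with a LowRankLinear and registers the split layer as a parallel branch via insert_parallel_module. A rough sketch of what such a low-rank parallel branch computes; ToyLowRank is a generic LoRA-style illustration, not OpenDelta's LowRankLinear, and the rank r, alpha, and zero initialization of B are assumptions:

import torch
import torch.nn as nn

class ToyLowRank(nn.Module):
    # Generic low-rank update: B(A(x)) scaled by alpha / r.
    def __init__(self, in_features, out_features, r=8, alpha=16):
        super().__init__()
        self.A = nn.Linear(in_features, r, bias=False)
        self.B = nn.Linear(r, out_features, bias=False)
        nn.init.zeros_(self.B.weight)      # branch starts as a no-op
        self.scaling = alpha / r

    def forward(self, x):
        return self.B(self.A(x)) * self.scaling

base = nn.Linear(32, 32)
branch = ToyLowRank(32, 32)
x = torch.randn(4, 32)
y = base(x) + branch(x)   # "parallel" insertion: branch output is added to the base output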
test.py (17 changed lines)
@@ -3,13 +3,24 @@ model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model, batch_size=[1]*16, modified_modules=['output.dense'])
delta_model = SplitModel(model)
print("attach here")
delta_model.attach("output.dense", "adapter_A", "adapter")
delta_model.attach("output.dense", "adapter_B", "adapter")
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
delta_model.update()
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
delta_model.log() # This will visualize the backbone after modification and other information.
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
import time
model = model.cuda()
st_time = time.time()
for t in range(10):
    y = model(x)
print("run here")
y = model(x)
print("detach here")
delta_model.detach("3.output.dense", "adapter_A")
delta_model.log()
print("run here")
y = model(x)
print(time.time() - st_time)
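Since the loop in test.py only benchmarks forward passes, the timing is cleaner when run under torch.no_grad() with a CUDA synchronization before reading the clock. A small variant of the timing block above, offered as a suggestion rather than part of the commit; it reuses the model and x defined in the script:

import time
import torch

model.eval()
with torch.no_grad():
    torch.cuda.synchronize()
    st_time = time.time()
    for _ in range(10):
        y = model(x)
    torch.cuda.synchronize()
    print(time.time() - st_time)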