merge parallel adapter

shengdinghu 2022-10-17 06:46:30 +00:00
commit 26e45110b2
3 changed files with 100 additions and 92 deletions


@@ -26,6 +26,7 @@ from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.cuda import move_dict_to_cuda
import sys
from opendelta.utils.data_parallel import caller_map
logger = logging.get_logger(__name__)
def is_leaf_module(module):
@@ -531,7 +532,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
raise NotImplementedError
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
def insert_module(self, module, method='sequential', delta_module=None, delta_name='delta', strict=False, _delta_info=None):
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the new module's forward function and then pass
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
@@ -547,15 +548,6 @@ class DeltaBase(nn.Module, SaveLoadMixin):
original delta is passed through ``_delta_info``.
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):# is not None:
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):# is not None:
ret = delta_module.post_forward(ret)
return ret
if strict:
@@ -566,9 +558,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_sequential",
"delta_module": delta_module,
_delta_info = {"method": method,
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
@@ -580,12 +572,36 @@ class DeltaBase(nn.Module, SaveLoadMixin):
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
if _delta_info["method"] in caller_map.keys():
caller = caller_map[_delta_info["method"]]
new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
else:
raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the new module's forward function and then pass
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
When implementing the new module , researchers should be aware of the components of arguments of the original module's forward function.
Args:
module: (:obj:`nn.Module`): The (sub)module into which a delta module is inserted.
delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
delta_name: (:obj:`str`, *optional*): The name of the delta module in the backbone module.
strict: (:obj:`bool`, *optional*): Whether to prohibit modifying a module that has already been modified.
_delta_info (:obj:`Dict`, *optional*): Used in attach() to reattach a delta module to the backbone. The info of the
original delta is passed through ``_delta_info``.
"""
self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
"""insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward
@@ -604,41 +620,8 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_parallel",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)
def set_active_state_dict(self, module: nn.Module):
r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.


@@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
# TODO:
# If we want to have a layer norm on output, we apply it later after a separate residual connection
# This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
# if self.add_layer_norm_after:
# self.adapter_norm_after = nn.LayerNorm(self.input_size)
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
@@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
self.instantiate(hidden_dim=self.hidden_dim)
self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
self.adapter_output = self.modulelist(hiddens) * self.scaled
return args, kwargs
def post_forward(self, *args, **kwargs):
if isinstance(args, tuple):
output = args[0]
elif isinstance(args, torch.Tensor):
output = args
def post_forward(self, output, **kwargs):
if isinstance(output, tuple):
hidden = output[0]
elif isinstance(output, torch.Tensor):
hidden = output
else:
raise TypeError
modified_output = self.adapter_output + output
return modified_output
modified_output = self.adapter_output + hidden
if isinstance(output, tuple):
output = (modified_output,) + output[1:]
elif isinstance(output, torch.Tensor):
output = modified_output
else:
raise TypeError
return output
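Taken together, pre_forward (inserted with method "before" at the first module of a pair) caches the adapter output computed from that module's input, and post_forward (inserted with method "after" at the second module) adds the cached value to that module's output. A standalone sketch of the same data flow, with made-up layer names:

import torch
import torch.nn as nn

pair_in, pair_out = nn.Linear(8, 8), nn.Linear(8, 8)                  # the paired sub-modules
adapter = nn.Sequential(nn.Linear(8, 2), nn.ReLU(), nn.Linear(2, 8))  # bottleneck adapter
scale = 1.0

x = torch.randn(4, 8)
adapter_output = adapter(x) * scale   # pre_forward: computed from the pair's *input*
h = pair_out(pair_in(x))              # backbone path between the two paired modules
y = adapter_output + h                # post_forward: added to the pair's *output*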
@@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
non_linearity (:obj:`str`): The non-linearity of the adapter.
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
modified_modules (:obj:`List[str]`): modules to add the parallel adapter to. Must be paired and in the same order as the modules appear in the layer. For example, ["attn", "attn", "ff.w1", "ff.w2"] adds one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
common_structure (:obj:`bool`): whether to use name-based addressing with a common structure mapping.
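A hedged usage sketch (the backbone, the import path, and the module names are illustrative; the names must resolve against the backbone's named sub-modules or the common structure mapping):

from transformers import AutoModelForSeq2SeqLM
from opendelta import ParallelAdapterModel   # assumed export path

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
# Paired entries, following the docstring example: one adapter spans attn's input to
# attn's output, another spans ff.w1's input to ff.w2's output.
delta_model = ParallelAdapterModel(
    backbone_model=model,
    modified_modules=["attn", "attn", "ff.w1", "ff.w2"],
    bottleneck_dim=24,
    common_structure=True,
)
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)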
@@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
_, _, ref = self.find_module(module, key)
if self.ith % 2 == 0:
adapterlayer = self.new_module_like(ref)
self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
else:
adapterlayer = self.delta_moduels[-1]
self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
adapterlayer = self.delta_modules[-1]
self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
self.ith |= 1
self.ith += 1
self.ith %= len(self.modified_modules)
def new_module_like(self, module):
module_device = get_device(module)
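The update logic above relies on the parity of self.ith plus the equal-name check to decide which side of each pair a matched key gets. A standalone simulation of that bookkeeping (independent of opendelta; the matched names are illustrative and assume "attn" matches once per layer even though it is listed twice in the specification):

modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]   # the paired specification
matched_in_layer = ["attn", "ff.w1", "ff.w2"]           # keys matched within one layer, in order

ith = 0
for key in matched_in_layer:
    if ith % 2 == 0:
        print(f"new adapter, inserted BEFORE {key}")    # hooks the pair's input
    if ith % 2 == 1 or modified_modules[ith] == modified_modules[ith + 1]:
        print(f"same adapter, inserted AFTER {key}")    # hooks the pair's output
        ith |= 1                                        # jump to the second slot of the pair
    ith += 1
    ith %= len(modified_modules)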


@@ -4,26 +4,50 @@
from opendelta.utils.decorate import decorate
from collections import OrderedDict
def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
return ret
def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
caller_map = {
"sequential": sequential_caller,
"parallel": parallel_caller,
"before": before_caller,
"after": after_caller,
}
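Conceptually, the four callers implement the following compositions of the wrapped forward f, the delta's pre_forward p, its post_forward q, and its own forward d (a plain-Python restatement for intuition, not opendelta code):

def sequential(f, p, q, x): return q(f(p(x)))    # sequential_caller: pre, original, post
def before(f, p, x):        return f(p(x))       # before_caller: only pre_forward
def after(f, q, x):         return q(f(x))       # after_caller: only post_forward
def parallel(f, d, x):      return f(x) + d(x)   # parallel_caller: additive branch

f = lambda x: 2 * x
assert sequential(f, lambda x: x + 1, lambda x: x - 1, 3) == 7
assert parallel(f, lambda x: x, 3) == 9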
def new_replicate_for_data_parallel(self):
r""" self is the parent module.
"""
# rewrite the replicate in DataParallel.
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
replica = self.__new__(type(self))
org_forward = replica.forward
replica.__dict__ = self.__dict__.copy()
@@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):
for _delta_info in self._delta_infos:
if _delta_info['state'] == 'on':
if _delta_info['method'] == "insert_sequential":
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
elif _delta_info['method'] == "insert_parallel":
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
if _delta_info['method'] in caller_map.keys():
caller = caller_map[_delta_info['method']]
new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
else:
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
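A hedged usage note: after a delta is attached, the backbone can still be wrapped in torch.nn.DataParallel, because each replica re-derives its wrapped forward from the recorded _delta_infos through caller_map. Sketch only; the delta-modified `model` is assumed to exist, e.g. from the ParallelAdapterModel example above.

import torch

if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)   # replicas re-wrap forward via new_replicate_for_data_parallel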