merge parallel adapter
This commit is contained in:
commit
26e45110b2
|
@ -26,6 +26,7 @@ from opendelta.utils.data_parallel import new_replicate_for_data_parallel
|
|||
from opendelta.utils.cuda import move_dict_to_cuda
|
||||
import sys
|
||||
|
||||
from opendelta.utils.data_parallel import caller_map
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def is_leaf_module(module):
|
||||
|
@ -531,7 +532,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||
def insert_module(self, module, method='sequential', delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
|
||||
function of the original module to firstly pass the arguments into the new module's forward function and then pass
|
||||
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
|
||||
|
@ -547,15 +548,6 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
original delta is passed through ``_delta_info``.
|
||||
|
||||
"""
|
||||
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
if hasattr(delta_module, "pre_forward"):# is not None:
|
||||
args, kwargs = delta_module.pre_forward(*args, **kwargs)
|
||||
ret = _org_func(*args, **kwargs)
|
||||
if hasattr(delta_module, "post_forward"):# is not None:
|
||||
ret = delta_module.post_forward(ret)
|
||||
return ret
|
||||
|
||||
|
||||
if strict:
|
||||
|
@ -566,9 +558,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
if _delta_info is None:
|
||||
if delta_module is None:
|
||||
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
|
||||
|
||||
_delta_info = {"method": "insert_sequential",
|
||||
"delta_module": delta_module,
|
||||
|
||||
_delta_info = {"method": method,
|
||||
"delta_module": delta_module,
|
||||
"delta_name": delta_name,
|
||||
"delta_belong": self,
|
||||
"state": "on"}
|
||||
|
@ -580,12 +572,36 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
|
||||
|
||||
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
|
||||
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
|
||||
# for DataParallel's copy behavior. Experimental:
|
||||
# may have bugs when module.forward is nestedly wrapped.
|
||||
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
||||
|
||||
if _delta_info["method"] in caller_map.keys():
|
||||
caller = caller_map[_delta_info["method"]]
|
||||
new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
|
||||
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
|
||||
# for DataParallel's copy behavior. Experimental:
|
||||
# may have bugs when module.forward is nestedly wrapped.
|
||||
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
||||
else:
|
||||
raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
|
||||
|
||||
|
||||
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
|
||||
function of the original module to firstly pass the arguments into the new module's forward function and then pass
|
||||
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
|
||||
|
||||
When implementing the new module , researchers should be aware of the components of arguments of the original module's forward function.
|
||||
|
||||
Args:
|
||||
module: (:obj:`nn.Module`): The (sub)module to inserted a delta module.
|
||||
delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
|
||||
name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
|
||||
strict: (:obj:`bool`, *optional*): Whether to prohibit modify a modified module.
|
||||
_delta_info (:obj:`Dict`, *optional*): Used in attach(), reattach a delta module to backbone. The info of
|
||||
original delta is passed through ``_delta_info``.
|
||||
|
||||
"""
|
||||
self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)
|
||||
|
||||
|
||||
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
|
||||
"""insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward
|
||||
|
@ -604,41 +620,8 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
|
||||
"""
|
||||
|
||||
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
ret_1 = _org_func(*args, **kwargs)
|
||||
ret_2 = delta_module.forward(*args, **kwargs)
|
||||
return ret_1 + ret_2
|
||||
|
||||
if strict:
|
||||
if hasattr(module.forward, "__wrapped__"):
|
||||
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
|
||||
|
||||
# record info for plug and unplug and nested wrap
|
||||
if _delta_info is None:
|
||||
if delta_module is None:
|
||||
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
|
||||
|
||||
_delta_info = {"method": "insert_parallel",
|
||||
"delta_module": delta_module,
|
||||
"delta_name": delta_name,
|
||||
"delta_belong": self,
|
||||
"state": "on"}
|
||||
self._register_delta_infos(parent_module=module,
|
||||
_delta_info = _delta_info)
|
||||
else:
|
||||
delta_module = _delta_info["delta_module"]
|
||||
delta_name = _delta_info["delta_name"]
|
||||
|
||||
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
|
||||
|
||||
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
|
||||
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
|
||||
# for DataParallel's copy behavior. Experimental:
|
||||
# may have bugs when module.forward is nestedly wrapped.
|
||||
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
|
||||
|
||||
self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)
|
||||
|
||||
|
||||
def set_active_state_dict(self, module: nn.Module):
|
||||
r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.
|
||||
|
|
|
@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):
|
|||
|
||||
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
|
||||
|
||||
# TODO:
|
||||
# If we want to have a layer norm on output, we apply it later after a separate residual connection
|
||||
# This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
|
||||
# if self.add_layer_norm_after:
|
||||
# self.adapter_norm_after = nn.LayerNorm(self.input_size)
|
||||
|
||||
self.instantiated = True
|
||||
# initialize the weight, which is important for fast convergence and better performance.
|
||||
self.apply(self._init_weight)
|
||||
|
@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
|
|||
self.instantiate(hidden_dim=self.hidden_dim)
|
||||
|
||||
|
||||
self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
|
||||
self.adapter_output = self.modulelist(hiddens) * self.scaled
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, *args, **kwargs):
|
||||
if isinstance(args, tuple):
|
||||
output = args[0]
|
||||
elif isinstance(args, torch.Tensor):
|
||||
output = args
|
||||
def post_forward(self, output, **kwargs):
|
||||
if isinstance(output, tuple):
|
||||
hidden = output[0]
|
||||
elif isinstance(output, torch.Tensor):
|
||||
hidden = output
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
modified_output = self.adapter_output + output
|
||||
return modified_output
|
||||
modified_output = self.adapter_output + hidden
|
||||
if isinstance(output, tuple):
|
||||
output = (modified_output,) + output[1:]
|
||||
elif isinstance(output, torch.Tensor):
|
||||
output = modified_output
|
||||
else:
|
||||
raise TypeError
|
||||
return output
|
||||
|
||||
|
||||
|
||||
|
@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
|
|||
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
|
||||
bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
|
||||
non_linearity (:obj:`str`): The non linearity of the adapter.
|
||||
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
|
||||
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
|
||||
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
|
||||
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
|
||||
|
||||
|
@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
|
|||
_, _, ref = self.find_module(module, key)
|
||||
if self.ith % 2 == 0:
|
||||
adapterlayer = self.new_module_like(ref)
|
||||
self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
|
||||
else:
|
||||
adapterlayer = self.delta_moduels[-1]
|
||||
self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
|
||||
self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
|
||||
if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
|
||||
adapterlayer = self.delta_modules[-1]
|
||||
self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
|
||||
self.ith |= 1
|
||||
self.ith += 1
|
||||
self.ith %= len(self.modified_modules)
|
||||
|
||||
def new_module_like(self, module):
|
||||
module_device = get_device(module)
|
||||
|
|
|
@ -4,26 +4,50 @@
|
|||
from opendelta.utils.decorate import decorate
|
||||
from collections import OrderedDict
|
||||
|
||||
def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
if hasattr(delta_module, "pre_forward"):
|
||||
args, kwargs = delta_module.pre_forward(*args, **kwargs)
|
||||
ret = _org_func(*args, **kwargs)
|
||||
if hasattr(delta_module, "post_forward"):
|
||||
ret = delta_module.post_forward(ret)
|
||||
return ret
|
||||
|
||||
def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
if hasattr(delta_module, "pre_forward"):
|
||||
args, kwargs = delta_module.pre_forward(*args, **kwargs)
|
||||
ret = _org_func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
ret = _org_func(*args, **kwargs)
|
||||
if hasattr(delta_module, "post_forward"):
|
||||
ret = delta_module.post_forward(ret)
|
||||
return ret
|
||||
|
||||
def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
ret_1 = _org_func(*args, **kwargs)
|
||||
ret_2 = delta_module.forward(*args, **kwargs)
|
||||
return ret_1 + ret_2
|
||||
|
||||
caller_map = {
|
||||
"sequential": sequential_caller,
|
||||
"parallel": parallel_caller,
|
||||
"before": before_caller,
|
||||
"after": after_caller,
|
||||
}
|
||||
|
||||
def new_replicate_for_data_parallel(self):
|
||||
r""" self is the parent module.
|
||||
"""
|
||||
# rewrite the replicate in DataParallel.
|
||||
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
if hasattr(delta_module, "pre_forward"):
|
||||
args, kwargs = delta_module.pre_forward(*args, **kwargs)
|
||||
ret = _org_func(*args, **kwargs)
|
||||
if hasattr(delta_module, "post_forward"):
|
||||
ret = delta_module.post_forward(ret)
|
||||
return ret
|
||||
|
||||
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
|
||||
args = args[1:] # the first argument here is ``self``
|
||||
delta_module = getattr(org_module, delta_name)
|
||||
ret_1 = _org_func(*args, **kwargs)
|
||||
ret_2 = delta_module.forward(*args, **kwargs)
|
||||
return ret_1 + ret_2
|
||||
replica = self.__new__(type(self))
|
||||
org_forward = replica.forward
|
||||
replica.__dict__ = self.__dict__.copy()
|
||||
|
@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):
|
|||
|
||||
for _delta_info in self._delta_infos:
|
||||
if _delta_info['state'] == 'on':
|
||||
if _delta_info['method'] == "insert_sequential":
|
||||
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
||||
elif _delta_info['method'] == "insert_parallel":
|
||||
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
||||
if _delta_info['method'] in caller_map.keys():
|
||||
caller = caller_map[_delta_info['method']]
|
||||
new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
|
||||
else:
|
||||
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
|
||||
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
|
||||
|
|
Loading…
Reference in New Issue