This commit is contained in:
Achazwl 2022-02-26 09:00:12 +08:00
parent 3468eb872c
commit 0d38e6509f
4 changed files with 108 additions and 114 deletions

View File

@ -22,6 +22,7 @@ from opendelta import logging
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.data_parallel import caller_map
logger = logging.get_logger(__name__)
def is_leaf_module(module):
@ -480,7 +481,41 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
raise NotImplementedError
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
def insert_module(self, module, method, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": method,
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
if _delta_info["method"] in caller_map.keys():
caller = caller_map[_delta_info["method"]]
new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
else:
raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the new module's forward function and then pass
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
@ -496,46 +531,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
original delta is passed through ``_delta_info``.
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):# is not None:
args, kwargs = delta_module.pre_forward(*args, **kwargs)
# from IPython import embed
# embed(header = "true")
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):# is not None:
ret = delta_module.post_forward(ret)
return ret
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_sequential",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
@ -555,40 +551,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_parallel",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)
def set_active_state_dict(self, module: nn.Module):

View File

@ -44,6 +44,12 @@ class LowRankLinear(nn.Module):
def forward(self, x):
return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
# def pre_forward(self, *args, **kwargs):
# return (args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling,), {}
# def post_forward(self, *args, **kwargs):
# return args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling
class LoraConfig(BaseDeltaConfig):
r"""

View File

@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
# TODO:
# If we want to have a layer norm on output, we apply it later after a separate residual connection
# This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
# if self.add_layer_norm_after:
# self.adapter_norm_after = nn.LayerNorm(self.input_size)
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
self.instantiate(hidden_dim=self.hidden_dim)
self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
self.adapter_output = self.modulelist(hiddens) * self.scaled
return args, kwargs
def post_forward(self, *args, **kwargs):
if isinstance(args, tuple):
output = args[0]
elif isinstance(args, torch.Tensor):
output = args
def post_forward(self, output, **kwargs):
if isinstance(output, tuple):
hidden = output[0]
elif isinstance(output, torch.Tensor):
hidden = output
else:
raise TypeError
modified_output = self.adapter_output + output
return modified_output
modified_output = self.adapter_output + hidden
if isinstance(output, tuple):
output = (modified_output,) + output[1:]
elif isinstance(output, torch.Tensor):
output = modified_output
else:
raise TypeError
return output
@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
non_linearity (:obj:`str`): The non linearity of the adapter.
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
_, _, ref = self.find_module(module, key)
if self.ith % 2 == 0:
adapterlayer = self.new_module_like(ref)
self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
else:
adapterlayer = self.delta_moduels[-1]
self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
adapterlayer = self.delta_modules[-1]
self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
self.ith |= 1
self.ith += 1
self.ith %= len(self.modified_modules)
def new_module_like(self, module):
module_device = get_device(module)

View File

@ -4,26 +4,50 @@
from opendelta.utils.decorate import decorate
from collections import OrderedDict
def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
return ret
def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
caller_map = {
"sequential": sequential_caller,
"parallel": parallel_caller,
"before": before_caller,
"after": after_caller,
}
def new_replicate_for_data_parallel(self):
r""" self is the parent module.
"""
# rewrite the replicate in DataParallel.
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
replica = self.__new__(type(self))
org_forward = replica.forward
replica.__dict__ = self.__dict__.copy()
@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):
for _delta_info in self._delta_infos:
if _delta_info['state'] == 'on':
if _delta_info['method'] == "insert_sequential":
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
elif _delta_info['method'] == "insert_parallel":
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
if _delta_info['method'] in caller_map.keys():
caller = caller_map[_delta_info['method']]
new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
else:
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))