Achazwl 2022-02-26 09:00:12 +08:00
parent 3468eb872c
commit 0d38e6509f
4 changed files with 108 additions and 114 deletions

View File

@ -22,6 +22,7 @@ from opendelta import logging
from opendelta.utils.structure_mapping import CommonStructureMap
from opendelta.utils.interactive.web import interactive
from opendelta.utils.data_parallel import new_replicate_for_data_parallel
from opendelta.utils.data_parallel import caller_map
logger = logging.get_logger(__name__)
def is_leaf_module(module):
@ -480,6 +481,40 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
raise NotImplementedError
def insert_module(self, module, method, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": method,
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
if _delta_info["method"] in caller_map.keys():
caller = caller_map[_delta_info["method"]]
new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserve the function's metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) registers a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
else:
raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the new module's forward function and then pass
@ -496,46 +531,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
original delta is passed through ``_delta_info``.
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):# is not None:
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):# is not None:
ret = delta_module.post_forward(ret)
return ret
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_sequential",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
@ -555,40 +551,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_parallel",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)
def set_active_state_dict(self, module: nn.Module):
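
To make the mechanism above concrete, here is a minimal, self-contained sketch (not part of this commit) of the forward-wrapping pattern that insert_module applies for the "sequential" method: decorator.decorate wraps the original bound forward in a caller while keeping its metadata, and __get__ re-binds the wrapped function as the module's own method. The nn.Linear backbone and the ScaleDelta module are assumptions for illustration only.

import torch
import torch.nn as nn
from decorator import decorate  # requires decorator>=5 for the extras/kwsyntax arguments

def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
    args = args[1:]                                  # drop the bound ``self``
    delta_module = getattr(org_module, delta_name)
    if hasattr(delta_module, "pre_forward"):
        args, kwargs = delta_module.pre_forward(*args, **kwargs)
    ret = _org_func(*args, **kwargs)
    if hasattr(delta_module, "post_forward"):
        ret = delta_module.post_forward(ret)
    return ret

class ScaleDelta(nn.Module):                         # hypothetical delta module
    def post_forward(self, ret):
        return 2 * ret

layer = nn.Linear(4, 4)
setattr(layer, "delta", ScaleDelta())                # attach the delta as a sub-module
new_forward = decorate(layer.forward, sequential_caller,
                       extras=(layer, "delta"), kwsyntax=True)
layer.forward = new_forward.__get__(layer, type(layer))

x = torch.randn(2, 4)
out = layer(x)                                       # forward now runs ScaleDelta.post_forward on the output
assert hasattr(layer.forward, "__wrapped__")         # the attribute the ``strict`` check looks for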

View File

@ -44,6 +44,12 @@ class LowRankLinear(nn.Module):
def forward(self, x):
return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
# def pre_forward(self, *args, **kwargs):
# return (args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling,), {}
# def post_forward(self, *args, **kwargs):
# return args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling
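
For reference, a sketch (an illustration, not part of this file) of how such a low-rank module composes with the backbone when inserted with the "parallel" method: the wrapped forward returns the sum of the original Linear output and the delta path, which is what the commented pre_forward/post_forward variants above would compute if enabled instead. The toy LowRank module, its dimensions, and the scaling value are assumptions.

import torch
import torch.nn as nn
from decorator import decorate  # decorator>=5

class LowRank(nn.Module):                            # hypothetical stand-in for LowRankLinear
    def __init__(self, in_dim, out_dim, r=4, scaling=1.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(r, in_dim) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_dim, r))
        self.scaling = scaling
    def forward(self, x):
        return (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
    args = args[1:]                                  # drop the bound ``self``
    delta_module = getattr(org_module, delta_name)
    return _org_func(*args, **kwargs) + delta_module.forward(*args, **kwargs)

linear = nn.Linear(16, 16)
setattr(linear, "lora", LowRank(16, 16))
wrapped = decorate(linear.forward, parallel_caller, extras=(linear, "lora"), kwsyntax=True)
linear.forward = wrapped.__get__(linear, type(linear))

out = linear(torch.randn(2, 16))                     # = original Linear output + scaled low-rank path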
class LoraConfig(BaseDeltaConfig):
r"""

View File

@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):
self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
# TODO:
# If we want to have a layer norm on output, we apply it later after a separate residual connection
# This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
# if self.add_layer_norm_after:
# self.adapter_norm_after = nn.LayerNorm(self.input_size)
self.instantiated = True
# initialize the weight, which is important for fast convergence and better performance.
self.apply(self._init_weight)
@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
self.instantiate(hidden_dim=self.hidden_dim)
self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
self.adapter_output = self.modulelist(hiddens) * self.scaled
return args, kwargs
def post_forward(self, *args, **kwargs):
if isinstance(args, tuple):
output = args[0]
elif isinstance(args, torch.Tensor):
output = args
def post_forward(self, output, **kwargs):
if isinstance(output, tuple):
hidden = output[0]
elif isinstance(output, torch.Tensor):
hidden = output
else:
raise TypeError(f"Expected a tuple or a Tensor as the module output, got {type(output)}")
modified_output = self.adapter_output + output
return modified_output
modified_output = self.adapter_output + hidden
if isinstance(output, tuple):
output = (modified_output,) + output[1:]
elif isinstance(output, torch.Tensor):
output = modified_output
else:
raise TypeError(f"Expected a tuple or a Tensor as the module output, got {type(output)}")
return output
@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
non_linearity (:obj:`str`): The non linearity of the adapter.
modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
modified_modules (:obj:`List[str]`): modules to add the parallel adapter to. Must be paired and follow the same order as in the layer. For example, ["attn", "attn", "ff.w1", "ff.w2"] adds one parallel adapter from attn's input to attn's output, and another from ff.w1's input to ff.w2's output.
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
common_structure (:obj:`bool`): whether to use name-based addressing with a common structure mapping.
@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
_, _, ref = self.find_module(module, key)
if self.ith % 2 == 0:
adapterlayer = self.new_module_like(ref)
self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
else:
adapterlayer = self.delta_moduels[-1]
self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
adapterlayer = self.delta_modules[-1]
self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
self.ith |= 1
self.ith += 1
self.ith %= len(self.modified_modules)
def new_module_like(self, module):
module_device = get_device(module)
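
A minimal sketch (not from the repository) of what the paired "before"/"after" insertion above achieves: one adapter instance is attached to two backbone sub-modules, its pre_forward caches an output computed from the first module's input, and its post_forward adds that cache to the second module's output. The two-Linear backbone and ToyParallelAdapter are assumptions, and the callers below omit the hasattr guards used in the real code.

import torch
import torch.nn as nn
from decorator import decorate  # decorator>=5

class ToyParallelAdapter(nn.Module):                 # hypothetical, mirrors ParallelAdapterLayer's hooks
    def __init__(self, dim, bottleneck=4):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)
    def pre_forward(self, *args, **kwargs):          # runs at the *first* module's input
        self.adapter_output = self.up(torch.relu(self.down(args[0])))
        return args, kwargs
    def post_forward(self, output):                  # runs at the *second* module's output
        return output + self.adapter_output

def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
    args = args[1:]                                  # drop the bound ``self``
    delta = getattr(org_module, delta_name)
    args, kwargs = delta.pre_forward(*args, **kwargs)
    return _org_func(*args, **kwargs)

def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
    args = args[1:]
    delta = getattr(org_module, delta_name)
    return delta.post_forward(_org_func(*args, **kwargs))

def wrap(module, caller, delta, name="parallel_adapter"):
    setattr(module, name, delta)
    new_forward = decorate(module.forward, caller, extras=(module, name), kwsyntax=True)
    module.forward = new_forward.__get__(module, type(module))

first, second = nn.Linear(8, 8), nn.Linear(8, 8)
adapter = ToyParallelAdapter(8)
wrap(first, before_caller, adapter)                  # adapter reads the input of ``first``
wrap(second, after_caller, adapter)                  # ...and modifies the output of ``second``

out = second(first(torch.randn(2, 8)))               # backbone span with the adapter applied around it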

View File

@ -4,11 +4,7 @@
from opendelta.utils.decorate import decorate
from collections import OrderedDict
def new_replicate_for_data_parallel(self):
r""" self is the parent module.
"""
# rewrite the replicate in DataParallel.
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
@ -18,12 +14,40 @@ def new_replicate_for_data_parallel(self):
ret = delta_module.post_forward(ret)
return ret
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
args, kwargs = delta_module.pre_forward(*args, **kwargs)
ret = _org_func(*args, **kwargs)
return ret
def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret = _org_func(*args, **kwargs)
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
caller_map = {
"sequential": sequential_caller,
"parallel": parallel_caller,
"before": before_caller,
"after": after_caller,
}
def new_replicate_for_data_parallel(self):
r""" self is the parent module.
"""
# rewrite the replicate in DataParallel.
replica = self.__new__(type(self))
org_forward = replica.forward
replica.__dict__ = self.__dict__.copy()
@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):
for _delta_info in self._delta_infos:
if _delta_info['state'] == 'on':
if _delta_info['method'] == "insert_sequential":
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
elif _delta_info['method'] == "insert_parallel":
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
if _delta_info['method'] in caller_map.keys():
caller = caller_map[_delta_info['method']]
new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
else:
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
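
Taken together, the refactor makes the initial wrapping in insert_module and the re-wrapping of DataParallel replicas dispatch on the same caller_map. Below is a hedged sketch of that shared flow; the add_delta helper and the simplified bookkeeping it writes are illustrative, not OpenDelta API, though the two imports do correspond to this commit.

import torch.nn as nn
from decorator import decorate
from opendelta.utils.data_parallel import caller_map, new_replicate_for_data_parallel

def add_delta(module: nn.Module, delta_module: nn.Module, method: str, delta_name: str = "delta"):
    """Illustrative helper: attach a delta and wrap ``module.forward`` via ``caller_map``."""
    setattr(module, delta_name, delta_module)
    caller = caller_map[method]                      # "sequential" / "parallel" / "before" / "after"
    new_forward = decorate(module.forward, caller,
                           extras=(module, delta_name), kwsyntax=True)
    module.forward = new_forward.__get__(module, type(module))
    # Minimal bookkeeping so that replication can look the method up again;
    # DeltaBase records a richer _delta_info dict via _register_delta_infos.
    module._delta_infos = [{"method": method, "delta_module": delta_module,
                            "delta_name": delta_name, "state": "on"}]
    # nn.DataParallel calls this hook when building replicas, and the hook re-applies
    # the caller selected from caller_map to each replica's forward.
    module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))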