init
This commit is contained in:
parent 3468eb872c
commit 0d38e6509f
@@ -22,6 +22,7 @@ from opendelta import logging
 from opendelta.utils.structure_mapping import CommonStructureMap
 from opendelta.utils.interactive.web import interactive
 from opendelta.utils.data_parallel import new_replicate_for_data_parallel
+from opendelta.utils.data_parallel import caller_map
 logger = logging.get_logger(__name__)

 def is_leaf_module(module):
@@ -480,7 +481,41 @@ class DeltaBase(nn.Module, SaveLoadMixin):
         """
         raise NotImplementedError

-    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
+    def insert_module(self, module, method, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
+        if strict:
+            if hasattr(module.forward, "__wrapped__"):
+                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
+
+        # record info for plug and unplug and nested wrap
+        if _delta_info is None:
+            if delta_module is None:
+                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
+
+            _delta_info = {"method": method,
+                           "delta_module": delta_module,
+                           "delta_name": delta_name,
+                           "delta_belong": self,
+                           "state": "on"}
+            self._register_delta_infos(parent_module=module,
+                                       _delta_info = _delta_info)
+        else:
+            delta_module = _delta_info["delta_module"]
+            delta_name = _delta_info["delta_name"]
+
+        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
+
+        if _delta_info["method"] in caller_map.keys():
+            caller = caller_map[_delta_info["method"]]
+            new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
+            module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) register a function as an object's method
+            # for DataParallel's copy behavior. Experimental:
+            # may have bugs when module.forward is nestedly wrapped.
+            module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        else:
+            raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
+
+
+    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
         r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
         function of the original module to firstly pass the arguments into the new module's forward function and then pass
         it into the original ones. The new module can also be inserted after the original module with similar mechanism.
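The new ``insert_module`` folds the different insertion styles into one code path: the ``method`` string is stored in ``_delta_info``, used to look the matching wrapper up in ``caller_map``, and the wrapper is baked into ``module.forward`` with ``decorate``. The sketch below is illustrative rather than part of the commit; it mimics that dispatch with a plain closure instead of ``opendelta.utils.decorate.decorate``, and ``toy_insert_module``, ``toy_caller_map``, and ``ToyDelta`` are invented names.

import torch
import torch.nn as nn

# Illustrative stand-ins for the wrappers registered in caller_map.
def sequential_caller(org_forward, delta_module, *args, **kwargs):
    # optional pre_forward rewrites the arguments before the wrapped forward runs
    if hasattr(delta_module, "pre_forward"):
        args, kwargs = delta_module.pre_forward(*args, **kwargs)
    ret = org_forward(*args, **kwargs)
    # optional post_forward adjusts the result afterwards
    if hasattr(delta_module, "post_forward"):
        ret = delta_module.post_forward(ret)
    return ret

def parallel_caller(org_forward, delta_module, *args, **kwargs):
    # run the backbone layer and the delta on the same input and sum the results
    return org_forward(*args, **kwargs) + delta_module(*args, **kwargs)

toy_caller_map = {"sequential": sequential_caller, "parallel": parallel_caller}

def toy_insert_module(module, method, delta_module, delta_name="delta"):
    setattr(module, delta_name, delta_module)   # keep the delta reachable as a sub-module
    caller = toy_caller_map[method]
    org_forward = module.forward                # bound method of the unmodified module
    def new_forward(*args, **kwargs):
        return caller(org_forward, getattr(module, delta_name), *args, **kwargs)
    module.forward = new_forward                # instance-level override, like the __get__ binding in the commit

class ToyDelta(nn.Module):
    def forward(self, x):
        return 0.1 * x

backbone = nn.Linear(4, 4)
toy_insert_module(backbone, "parallel", ToyDelta())
print(backbone(torch.randn(2, 4)).shape)        # torch.Size([2, 4])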
@@ -496,46 +531,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
             original delta is passed through ``_delta_info``.

         """
-        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
-            args = args[1:] # the first argument here is ``self``
-            delta_module = getattr(org_module, delta_name)
-            if hasattr(delta_module, "pre_forward"):# is not None:
-                args, kwargs = delta_module.pre_forward(*args, **kwargs)
-            # from IPython import embed
-            # embed(header = "true")
-            ret = _org_func(*args, **kwargs)
-            if hasattr(delta_module, "post_forward"):# is not None:
-                ret = delta_module.post_forward(ret)
-            return ret
-
-        if strict:
-            if hasattr(module.forward, "__wrapped__"):
-                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
-
-        # record info for plug and unplug and nested wrap
-        if _delta_info is None:
-            if delta_module is None:
-                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
-
-            _delta_info = {"method": "insert_sequential",
-                           "delta_module": delta_module,
-                           "delta_name": delta_name,
-                           "delta_belong": self,
-                           "state": "on"}
-            self._register_delta_infos(parent_module=module,
-                                       _delta_info = _delta_info)
-        else:
-            delta_module = _delta_info["delta_module"]
-            delta_name = _delta_info["delta_name"]
-
-        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
-
-        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
-        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) register a function as an object's method
-        # for DataParallel's copy behavior. Experimental:
-        # may have bugs when module.forward is nestedly wrapped.
-        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)


    def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
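``insert_sequential_module`` keeps its public signature and now simply delegates to ``insert_module(module, "sequential", ...)``. Sequential wrapping means the delta module's optional ``pre_forward`` sees the original arguments and its optional ``post_forward`` sees the original output. A minimal self-contained illustration of that contract; ``BiasDelta`` is invented for the example and is not an OpenDelta class.

import torch
import torch.nn as nn

class BiasDelta(nn.Module):
    """Toy delta module obeying the pre_forward/post_forward contract."""
    def __init__(self, dim):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(dim))

    def pre_forward(self, *args, **kwargs):
        # could inspect or rewrite the arguments before the backbone sub-module runs
        return args, kwargs

    def post_forward(self, output):
        # adjust the backbone sub-module's output afterwards
        return output + self.bias

# What the wrapped forward effectively computes after a "sequential" insertion:
layer = nn.Linear(4, 4)
delta = BiasDelta(4)
x = torch.randn(2, 4)
args, kwargs = delta.pre_forward(x)
out = delta.post_forward(layer(*args, **kwargs))
print(out.shape)  # torch.Size([2, 4])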
@@ -555,40 +551,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):

         """
-        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
-            args = args[1:] # the first argument here is ``self``
-            delta_module = getattr(org_module, delta_name)
-            ret_1 = _org_func(*args, **kwargs)
-            ret_2 = delta_module.forward(*args, **kwargs)
-            return ret_1 + ret_2
-
-        if strict:
-            if hasattr(module.forward, "__wrapped__"):
-                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
-
-        # record info for plug and unplug and nested wrap
-        if _delta_info is None:
-            if delta_module is None:
-                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
-
-            _delta_info = {"method": "insert_parallel",
-                           "delta_module": delta_module,
-                           "delta_name": delta_name,
-                           "delta_belong": self,
-                           "state": "on"}
-            self._register_delta_infos(parent_module=module,
-                                       _delta_info = _delta_info)
-        else:
-            delta_module = _delta_info["delta_module"]
-            delta_name = _delta_info["delta_name"]
-
-        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
-
-        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
-        module.forward = new_forward.__get__(module, type(module))  # func.__get__(object, type(object)) register a function as an object's method
-        # for DataParallel's copy behavior. Experimental:
-        # may have bugs when module.forward is nestedly wrapped.
-        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)


    def set_active_state_dict(self, module: nn.Module):
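``insert_parallel_module`` likewise delegates, with ``method="parallel"``: the wrapper runs the original forward and the delta module on the same inputs and sums the two results, so both branches must return addable tensors. A self-contained sketch of that composition with arbitrary toy modules:

import torch
import torch.nn as nn

backbone_layer = nn.Linear(8, 8)
parallel_branch = nn.Sequential(       # e.g. a bottleneck branch running beside the layer
    nn.Linear(8, 2, bias=False),
    nn.ReLU(),
    nn.Linear(2, 8, bias=False),
)

x = torch.randn(3, 8)
# What parallel_caller computes once insert_module(layer, "parallel", ...) has wrapped forward:
out = backbone_layer(x) + parallel_branch(x)
print(out.shape)  # torch.Size([3, 8])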
@@ -44,6 +44,12 @@ class LowRankLinear(nn.Module):
     def forward(self, x):
         return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling

+    # def pre_forward(self, *args, **kwargs):
+    #     return (args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling,), {}
+
+    # def post_forward(self, *args, **kwargs):
+    #     return args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+

 class LoraConfig(BaseDeltaConfig):
     r"""
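``LowRankLinear.forward`` returns only the low-rank update ``(dropout(x) @ A.T @ B.T) * scaling``; combined with the frozen linear layer's own output through the parallel mechanism, this gives the usual LoRA update. A standalone sketch of the arithmetic, with the dimensions, the initialization, and the ``lora_alpha / r`` scaling chosen as common LoRA conventions rather than taken from this commit:

import torch
import torch.nn as nn

in_features, out_features, r, lora_alpha = 16, 16, 4, 8

frozen = nn.Linear(in_features, out_features)               # pretrained weight, kept frozen
lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)   # low-rank factor A: (r, in)
lora_B = nn.Parameter(torch.zeros(out_features, r))         # low-rank factor B: (out, r), zero-init
dropout = nn.Dropout(p=0.0)
scaling = lora_alpha / r

x = torch.randn(2, in_features)
# low-rank branch, exactly what LowRankLinear.forward returns
delta_out = (dropout(x) @ lora_A.T @ lora_B.T) * scaling
# combined with the frozen layer via the parallel insertion
out = frozen(x) + delta_out
print(out.shape)  # torch.Size([2, 16])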
@@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):

         self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))

-        # TODO:
-        # If we want to have a layer norm on output, we apply it later after a separate residual connection
-        # This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
-        # if self.add_layer_norm_after:
-        #     self.adapter_norm_after = nn.LayerNorm(self.input_size)
-
         self.instantiated = True
         # initialize the weight, which is important for fast convergence and better performance.
         self.apply(self._init_weight)
@@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
             self.instantiate(hidden_dim=self.hidden_dim)


-        self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
+        self.adapter_output = self.modulelist(hiddens) * self.scaled
         return args, kwargs

-    def post_forward(self, *args, **kwargs):
-        if isinstance(args, tuple):
-            output = args[0]
-        elif isinstance(args, torch.Tensor):
-            output = args
+    def post_forward(self, output, **kwargs):
+        if isinstance(output, tuple):
+            hidden = output[0]
+        elif isinstance(output, torch.Tensor):
+            hidden = output
         else:
             raise TypeError

-        modified_output = self.adapter_output + output
-        return modified_output
+        modified_output = self.adapter_output + hidden
+        if isinstance(output, tuple):
+            output = (modified_output,) + output[1:]
+        elif isinstance(output, torch.Tensor):
+            output = modified_output
+        else:
+            raise TypeError
+        return output


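The corrected ``post_forward`` now returns the same container type it received: when the wrapped sub-module returns a tuple (hidden states plus attention weights or caches), only the first element is modified and the rest is passed through unchanged. A standalone sketch of that re-wrapping logic with dummy tensors:

import torch

def rewrap_first(output, adapter_output):
    """Add adapter_output to the hidden states while preserving the output container."""
    if isinstance(output, tuple):
        hidden = output[0]
    elif isinstance(output, torch.Tensor):
        hidden = output
    else:
        raise TypeError(f"unexpected output type: {type(output)}")

    modified = adapter_output + hidden
    if isinstance(output, tuple):
        return (modified,) + output[1:]   # keep attention weights / caches untouched
    return modified

hidden_states = torch.zeros(2, 5, 8)
attn_weights = torch.zeros(2, 5, 5)
adapter_out = torch.ones(2, 5, 8)

wrapped = rewrap_first((hidden_states, attn_weights), adapter_out)
print(type(wrapped), wrapped[0].mean().item())  # <class 'tuple'> 1.0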
@@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
         backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
         bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
         non_linearity (:obj:`str`): The non linearity of the adapter.
-        modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
+        modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the same order as in the layer. For example, ["attn", "attn", "ff.w1", "ff.w2"] adds one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
         common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.

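Because the entries come in pairs, each even-indexed name marks where the adapter reads its input (a "before" insertion) and the following name marks where it adds its output (an "after" insertion). A small illustration of how the paired list from the docstring decomposes; only the pairing logic is shown, not the actual module matching:

# How the paired list from the docstring decomposes into adapter spans.
modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]
pairs = list(zip(modified_modules[0::2], modified_modules[1::2]))
for start, end in pairs:
    print(f"one parallel adapter: input taken before {start!r}, output added after {end!r}")
# one parallel adapter: input taken before 'attn', output added after 'attn'
# one parallel adapter: input taken before 'ff.w1', output added after 'ff.w2'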
@@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
         _, _, ref = self.find_module(module, key)
         if self.ith % 2 == 0:
             adapterlayer = self.new_module_like(ref)
-            self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
-        else:
-            adapterlayer = self.delta_moduels[-1]
-            self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
+            self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
+        if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
+            adapterlayer = self.delta_modules[-1]
+            self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
+            self.ith |= 1
         self.ith += 1
+        self.ith %= len(self.modified_modules)

     def new_module_like(self, module):
         module_device = get_device(module)
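The new bookkeeping lets a pair such as ``["attn", "attn"]`` attach both ends of an adapter to the same sub-module in a single call: the duplicate-name check triggers the "after" insertion immediately, ``self.ith |= 1`` advances an even index to its odd partner, and the final modulo wraps the counter for the next layer. The trace below replays just the index arithmetic, assuming ``update_module`` fires once per matched sub-module per layer; it is a reading aid, not code from the repository.

# Stand-alone replay of the index bookkeeping in update_module.
modified_modules = ["attn", "attn", "ff.w1", "ff.w2"]
visited = ["attn", "ff.w1", "ff.w2"]   # assumed: one update_module call per matched sub-module per layer

ith = 0
for key in visited:
    actions = []
    if ith % 2 == 0:
        actions.append("before")       # new adapter, pre_forward reads the input
    if ith % 2 == 1 or modified_modules[ith] == modified_modules[ith + 1]:
        actions.append("after")        # reuse the last adapter, post_forward adds the output
        ith |= 1                       # jump to the odd slot when both ends name the same module
    ith += 1
    ith %= len(modified_modules)       # wrap around for the next layer
    print(key, actions, "next ith =", ith)
# attn ['before', 'after'] next ith = 2
# ff.w1 ['before'] next ith = 3
# ff.w2 ['after'] next ith = 0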
@@ -4,26 +4,50 @@
 from opendelta.utils.decorate import decorate
 from collections import OrderedDict

+def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    if hasattr(delta_module, "pre_forward"):
+        args, kwargs = delta_module.pre_forward(*args, **kwargs)
+    ret = _org_func(*args, **kwargs)
+    if hasattr(delta_module, "post_forward"):
+        ret = delta_module.post_forward(ret)
+    return ret
+
+def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    if hasattr(delta_module, "pre_forward"):
+        args, kwargs = delta_module.pre_forward(*args, **kwargs)
+    ret = _org_func(*args, **kwargs)
+    return ret
+
+def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    ret = _org_func(*args, **kwargs)
+    if hasattr(delta_module, "post_forward"):
+        ret = delta_module.post_forward(ret)
+    return ret
+
+def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    ret_1 = _org_func(*args, **kwargs)
+    ret_2 = delta_module.forward(*args, **kwargs)
+    return ret_1 + ret_2
+
+caller_map = {
+    "sequential": sequential_caller,
+    "parallel": parallel_caller,
+    "before": before_caller,
+    "after": after_caller,
+}
+
 def new_replicate_for_data_parallel(self):
     r""" self is the parent module.
     """
     # rewrite the replicate in DataParallel.
-    def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
-        args = args[1:] # the first argument here is ``self``
-        delta_module = getattr(org_module, delta_name)
-        if hasattr(delta_module, "pre_forward"):
-            args, kwargs = delta_module.pre_forward(*args, **kwargs)
-        ret = _org_func(*args, **kwargs)
-        if hasattr(delta_module, "post_forward"):
-            ret = delta_module.post_forward(ret)
-        return ret
-
-    def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
-        args = args[1:] # the first argument here is ``self``
-        delta_module = getattr(org_module, delta_name)
-        ret_1 = _org_func(*args, **kwargs)
-        ret_2 = delta_module.forward(*args, **kwargs)
-        return ret_1 + ret_2
     replica = self.__new__(type(self))
     org_forward = replica.forward
     replica.__dict__ = self.__dict__.copy()
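All callers share one signature: the original bound forward, the parent module that owns the delta, the delta's attribute name, and then the call's own arguments, where ``args[0]`` is the duplicated ``self`` that ``decorate`` passes along. A hedged sketch of how an additional strategy could be registered through the same table; the ``"gated"`` key and the delta module's ``gate`` attribute are invented for this illustration and are not part of the commit.

from opendelta.utils.data_parallel import caller_map

# Hypothetical extra caller; the "gated" key and the delta module's ``gate``
# attribute are invented for this illustration only.
def gated_caller(_org_func, org_module, delta_name, *args, **kwargs):
    args = args[1:]  # drop the duplicated ``self`` that decorate passes along
    delta_module = getattr(org_module, delta_name)
    ret = _org_func(*args, **kwargs)
    # blend the backbone output with the delta branch through a learned gate
    return ret + delta_module.gate * delta_module(*args, **kwargs)

caller_map["gated"] = gated_caller  # insert_module(module, "gated", ...) would now dispatch here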
@@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):

     for _delta_info in self._delta_infos:
         if _delta_info['state'] == 'on':
-            if _delta_info['method'] == "insert_sequential":
-                new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
-            elif _delta_info['method'] == "insert_parallel":
-                new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
+            if _delta_info['method'] in caller_map.keys():
+                caller = caller_map[_delta_info['method']]
+                new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
             else:
                 raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
             replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))