From 0d38e6509f7281e20a3e92c3a62b9a4db9da3e9d Mon Sep 17 00:00:00 2001
From: Achazwl <323163497@qq.com>
Date: Sat, 26 Feb 2022 09:00:12 +0800
Subject: [PATCH] init

---
 opendelta/basemodel.py                     | 113 +++++++--------------
 opendelta/delta_models/lora.py             |   6 ++
 opendelta/delta_models/parallel_adapter.py |  40 ++++----
 opendelta/utils/data_parallel.py           |  63 ++++++++----
 4 files changed, 108 insertions(+), 114 deletions(-)

diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py
index ac21334..16d3af7 100644
--- a/opendelta/basemodel.py
+++ b/opendelta/basemodel.py
@@ -22,6 +22,7 @@ from opendelta import logging
 from opendelta.utils.structure_mapping import CommonStructureMap
 from opendelta.utils.interactive.web import interactive
 from opendelta.utils.data_parallel import new_replicate_for_data_parallel
+from opendelta.utils.data_parallel import caller_map
 logger = logging.get_logger(__name__)
 
 def is_leaf_module(module):
@@ -480,7 +481,41 @@ class DeltaBase(nn.Module, SaveLoadMixin):
         """
         raise NotImplementedError
 
-    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
+    def insert_module(self, module, method, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
+        if strict:
+            if hasattr(module.forward, "__wrapped__"):
+                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
+
+        # record info for plug and unplug and nested wrap
+        if _delta_info is None:
+            if delta_module is None:
+                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
+
+            _delta_info = {"method": method,
+                           "delta_module": delta_module,
+                           "delta_name": delta_name,
+                           "delta_belong": self,
+                           "state": "on"}
+            self._register_delta_infos(parent_module=module,
+                                       _delta_info = _delta_info)
+        else:
+            delta_module = _delta_info["delta_module"]
+            delta_name = _delta_info["delta_name"]
+
+        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
+
+        if _delta_info["method"] in caller_map.keys():
+            caller = caller_map[_delta_info["method"]]
+            new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserve the function's metadata (signature, etc.).
+            module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) registers a function as an object's method
+            # for DataParallel's copy behavior. Experimental:
+            # may have bugs when module.forward is nestedly wrapped.
+            module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        else:
+            raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported")
+
+
+    def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
         r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
         function of the original module to firstly pass the arguments into the new module's forward function and then pass
         it into the original ones. The new module can also be inserted after the original module with similar mechanism.
@@ -496,46 +531,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
                 original delta is passed through ``_delta_info``.
         """
 
-        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
-            args = args[1:] # the first argument here is ``self``
-            delta_module = getattr(org_module, delta_name)
-            if hasattr(delta_module, "pre_forward"):# is not None:
-                args, kwargs = delta_module.pre_forward(*args, **kwargs)
-            # from IPython import embed
-            # embed(header = "true")
-            ret = _org_func(*args, **kwargs)
-            if hasattr(delta_module, "post_forward"):# is not None:
-                ret = delta_module.post_forward(ret)
-            return ret
-
-
-        if strict:
-            if hasattr(module.forward, "__wrapped__"):
-                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
-
-        # record info for plug and unplug and nested wrap
-        if _delta_info is None:
-            if delta_module is None:
-                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
-
-            _delta_info = {"method": "insert_sequential",
-                           "delta_module": delta_module,
-                           "delta_name": delta_name,
-                           "delta_belong": self,
-                           "state": "on"}
-            self._register_delta_infos(parent_module=module,
-                                       _delta_info = _delta_info)
-        else:
-            delta_module = _delta_info["delta_module"]
-            delta_name = _delta_info["delta_name"]
-
-        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
-
-        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
-        module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
-        # for DataParallel's copy behavior. Experimental:
-        # may have bugs when module.forward is nestedly wrapped.
-        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info)
 
 
     def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
@@ -555,40 +551,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
         """
 
-        def _caller(_org_func, org_module, delta_name, *args, **kwargs):
-            args = args[1:] # the first argument here is ``self``
-            delta_module = getattr(org_module, delta_name)
-            ret_1 = _org_func(*args, **kwargs)
-            ret_2 = delta_module.forward(*args, **kwargs)
-            return ret_1 + ret_2
-
-        if strict:
-            if hasattr(module.forward, "__wrapped__"):
-                raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
-
-        # record info for plug and unplug and nested wrap
-        if _delta_info is None:
-            if delta_module is None:
-                raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
-
-            _delta_info = {"method": "insert_parallel",
-                           "delta_module": delta_module,
-                           "delta_name": delta_name,
-                           "delta_belong": self,
-                           "state": "on"}
-            self._register_delta_infos(parent_module=module,
-                                       _delta_info = _delta_info)
-        else:
-            delta_module = _delta_info["delta_module"]
-            delta_name = _delta_info["delta_name"]
-
-        setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
-
-        new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
-        module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
-        # for DataParallel's copy behavior. Experimental:
-        # may have bugs when module.forward is nestedly wrapped.
-        module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
+        self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info)
 
 
     def set_active_state_dict(self, module: nn.Module):
diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py
index 492fea6..09399eb 100644
--- a/opendelta/delta_models/lora.py
+++ b/opendelta/delta_models/lora.py
@@ -44,6 +44,12 @@ class LowRankLinear(nn.Module):
 
     def forward(self, x):
         return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+    # def pre_forward(self, *args, **kwargs):
+    #     return (args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling,), {}
+
+    # def post_forward(self, *args, **kwargs):
+    #     return args[0] + (self.lora_dropout(args[0]) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+
 
 class LoraConfig(BaseDeltaConfig):
     r"""
diff --git a/opendelta/delta_models/parallel_adapter.py b/opendelta/delta_models/parallel_adapter.py
index 7290614..b6530f3 100644
--- a/opendelta/delta_models/parallel_adapter.py
+++ b/opendelta/delta_models/parallel_adapter.py
@@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module):
         self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device))
 
-        # TODO:
-        # If we want to have a layer norm on output, we apply it later after a separate residual connection
-        # This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer
-        # if self.add_layer_norm_after:
-        #     self.adapter_norm_after = nn.LayerNorm(self.input_size)
-
         self.instantiated = True
         # initialize the weight, which is important for fast convergence and better performance.
         self.apply(self._init_weight)
@@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module):
             self.instantiate(hidden_dim=self.hidden_dim)
 
-        self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens?
+        self.adapter_output = self.modulelist(hiddens) * self.scaled
         return args, kwargs
 
-    def post_forward(self, *args, **kwargs):
-        if isinstance(args, tuple):
-            output = args[0]
-        elif isinstance(args, torch.Tensor):
-            output = args
+    def post_forward(self, output, **kwargs):
+        if isinstance(output, tuple):
+            hidden = output[0]
+        elif isinstance(output, torch.Tensor):
+            hidden = output
         else:
             raise TypeError
-        modified_output = self.adapter_output + output
-        return modified_output
+        modified_output = self.adapter_output + hidden
+        if isinstance(output, tuple):
+            output = (modified_output,) + output[1:]
+        elif isinstance(output, torch.Tensor):
+            output = modified_output
+        else:
+            raise TypeError
+        return output
 
@@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase):
         backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
         bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck.
         non_linearity (:obj:`str`): The non linearity of the adapter.
-        modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
+        modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and appear in the same order as in the layer. For example, ["attn", "attn", "ff.w1", "ff.w2"] adds one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output.
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters.
         common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
@@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase):
         _, _, ref = self.find_module(module, key)
         if self.ith % 2 == 0:
             adapterlayer = self.new_module_like(ref)
-            self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
-        else:
-            adapterlayer = self.delta_moduels[-1]
-            self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter")
+            self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter")
+        if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]:
+            adapterlayer = self.delta_modules[-1]
+            self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter")
+            self.ith |= 1
         self.ith += 1
+        self.ith %= len(self.modified_modules)
 
     def new_module_like(self, module):
         module_device = get_device(module)
diff --git a/opendelta/utils/data_parallel.py b/opendelta/utils/data_parallel.py
index ca0c4c0..8c32297 100644
--- a/opendelta/utils/data_parallel.py
+++ b/opendelta/utils/data_parallel.py
@@ -4,26 +4,50 @@
 from opendelta.utils.decorate import decorate
 from collections import OrderedDict
 
+def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    if hasattr(delta_module, "pre_forward"):
+        args, kwargs = delta_module.pre_forward(*args, **kwargs)
+    ret = _org_func(*args, **kwargs)
+    if hasattr(delta_module, "post_forward"):
+        ret = delta_module.post_forward(ret)
+    return ret
+
+def before_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    if hasattr(delta_module, "pre_forward"):
+        args, kwargs = delta_module.pre_forward(*args, **kwargs)
+    ret = _org_func(*args, **kwargs)
+    return ret
+
+def after_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    ret = _org_func(*args, **kwargs)
+    if hasattr(delta_module, "post_forward"):
+        ret = delta_module.post_forward(ret)
+    return ret
+
+def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
+    args = args[1:] # the first argument here is ``self``
+    delta_module = getattr(org_module, delta_name)
+    ret_1 = _org_func(*args, **kwargs)
+    ret_2 = delta_module.forward(*args, **kwargs)
+    return ret_1 + ret_2
+
+caller_map = {
+    "sequential": sequential_caller,
+    "parallel": parallel_caller,
+    "before": before_caller,
+    "after": after_caller,
+}
+
 def new_replicate_for_data_parallel(self):
     r""" self is the parent module.
     """
     # rewrite the replicate in DataParallel.
-    def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
-        args = args[1:] # the first argument here is ``self``
-        delta_module = getattr(org_module, delta_name)
-        if hasattr(delta_module, "pre_forward"):
-            args, kwargs = delta_module.pre_forward(*args, **kwargs)
-        ret = _org_func(*args, **kwargs)
-        if hasattr(delta_module, "post_forward"):
-            ret = delta_module.post_forward(ret)
-        return ret
-
-    def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
-        args = args[1:] # the first argument here is ``self``
-        delta_module = getattr(org_module, delta_name)
-        ret_1 = _org_func(*args, **kwargs)
-        ret_2 = delta_module.forward(*args, **kwargs)
-        return ret_1 + ret_2
     replica = self.__new__(type(self))
     org_forward = replica.forward
     replica.__dict__ = self.__dict__.copy()
@@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self):
 
     for _delta_info in self._delta_infos:
         if _delta_info['state'] == 'on':
-            if _delta_info['method'] == "insert_sequential":
-                new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
-            elif _delta_info['method'] == "insert_parallel":
-                new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
+            if _delta_info['method'] in caller_map.keys():
+                caller = caller_map[_delta_info['method']]
+                new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
             else:
                 raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
             replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
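
Illustration (not part of the patch): the change above collects the four forward wrappers into caller_map in opendelta/utils/data_parallel.py, so that insert_module() in basemodel.py and new_replicate_for_data_parallel() dispatch on the same table of methods ("sequential", "parallel", "before", "after"). The sketch below shows what the "sequential" and "parallel" callers do to a wrapped module's forward. It is a minimal sketch only: it uses plain closures and a hypothetical ToyDelta module instead of decorator.decorate, whereas the actual code additionally preserves the forward signature via decorate(..., kwsyntax=True) and rebinds with __get__ so that DataParallel replicas can be re-wrapped.

# Illustration only -- ToyDelta and the wrap_* helpers are hypothetical
# stand-ins that mimic sequential_caller / parallel_caller from the patch.
import torch
import torch.nn as nn


class ToyDelta(nn.Module):
    # exposes the pre_forward / post_forward hooks that sequential_caller checks for
    def pre_forward(self, *args, **kwargs):
        return (args[0] + 1,), kwargs          # modify the wrapped module's input

    def post_forward(self, ret):
        return ret * 2                         # modify the wrapped module's output


def wrap_sequential(module, delta_name):
    # rough equivalent of insert_module(module, "sequential", ...), without the
    # signature-preserving decorate() call used in the real code
    org_forward = module.forward
    def new_forward(*args, **kwargs):
        delta_module = getattr(module, delta_name)
        if hasattr(delta_module, "pre_forward"):
            args, kwargs = delta_module.pre_forward(*args, **kwargs)
        ret = org_forward(*args, **kwargs)
        if hasattr(delta_module, "post_forward"):
            ret = delta_module.post_forward(ret)
        return ret
    module.forward = new_forward


def wrap_parallel(module, delta_name):
    # rough equivalent of insert_module(module, "parallel", ...): run the delta
    # branch on the same input and sum both outputs
    org_forward = module.forward
    def new_forward(*args, **kwargs):
        delta_module = getattr(module, delta_name)
        return org_forward(*args, **kwargs) + delta_module(*args, **kwargs)
    module.forward = new_forward


seq_backbone = nn.Identity()
seq_backbone.add_module("delta", ToyDelta())
wrap_sequential(seq_backbone, "delta")
print(seq_backbone(torch.zeros(1)))            # tensor([2.]) == (0 + 1) * 2

par_backbone = nn.Linear(4, 4)
par_backbone.add_module("parallel_delta", nn.Linear(4, 4, bias=False))
wrap_parallel(par_backbone, "parallel_delta")
print(par_backbone(torch.zeros(4)).shape)      # torch.Size([4]): backbone output + delta branch output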