diff --git a/opendelta/basemodel.py b/opendelta/basemodel.py index fbeab7f..6ed815e 100644 --- a/opendelta/basemodel.py +++ b/opendelta/basemodel.py @@ -26,6 +26,7 @@ from opendelta.utils.data_parallel import new_replicate_for_data_parallel from opendelta.utils.cuda import move_dict_to_cuda import sys +from opendelta.utils.data_parallel import caller_map logger = logging.get_logger(__name__) def is_leaf_module(module): @@ -531,7 +532,7 @@ class DeltaBase(nn.Module, SaveLoadMixin): """ raise NotImplementedError - def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None): + def insert_module(self, module, method='sequential', delta_module=None, delta_name='delta', strict=False, _delta_info=None): r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward function of the original module to firstly pass the arguments into the new module's forward function and then pass it into the original ones. The new module can also be inserted after the original module with similar mechanism. @@ -547,15 +548,6 @@ class DeltaBase(nn.Module, SaveLoadMixin): original delta is passed through ``_delta_info``. """ - def _caller(_org_func, org_module, delta_name, *args, **kwargs): - args = args[1:] # the first argument here is ``self`` - delta_module = getattr(org_module, delta_name) - if hasattr(delta_module, "pre_forward"):# is not None: - args, kwargs = delta_module.pre_forward(*args, **kwargs) - ret = _org_func(*args, **kwargs) - if hasattr(delta_module, "post_forward"):# is not None: - ret = delta_module.post_forward(ret) - return ret if strict: @@ -566,9 +558,9 @@ class DeltaBase(nn.Module, SaveLoadMixin): if _delta_info is None: if delta_module is None: raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.") - - _delta_info = {"method": "insert_sequential", - "delta_module": delta_module, + + _delta_info = {"method": method, + "delta_module": delta_module, "delta_name": delta_name, "delta_belong": self, "state": "on"} @@ -580,12 +572,36 @@ class DeltaBase(nn.Module, SaveLoadMixin): setattr(module, _delta_info['delta_name'], _delta_info["delta_module"]) - new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.). - module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method - # for DataParallel's copy behavior. Experimental: - # may have bugs when module.forward is nestedly wrapped. - module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module)) + if _delta_info["method"] in caller_map.keys(): + caller = caller_map[_delta_info["method"]] + new_forward = decorate(module.forward, caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.). + module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method + # for DataParallel's copy behavior. Experimental: + # may have bugs when module.forward is nestedly wrapped. + module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module)) + else: + raise NotImplementedError(f"_delta_info['method']=='{_delta_info['method']}' is not supported") + + + def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None): + r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward + function of the original module to firstly pass the arguments into the new module's forward function and then pass + it into the original ones. The new module can also be inserted after the original module with similar mechanism. + + When implementing the new module , researchers should be aware of the components of arguments of the original module's forward function. + + Args: + module: (:obj:`nn.Module`): The (sub)module to inserted a delta module. + delta_module: (:obj:`DeltaBase`): The delta module to be inserted. + name: (:obj:`str`, *optional*): The name of the delta in the backbone module. + strict: (:obj:`bool`, *optional*): Whether to prohibit modify a modified module. + _delta_info (:obj:`Dict`, *optional*): Used in attach(), reattach a delta module to backbone. The info of + original delta is passed through ``_delta_info``. + + """ + self.insert_module(module, "sequential", delta_module, delta_name, strict, _delta_info) + def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None): """insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward @@ -604,41 +620,8 @@ class DeltaBase(nn.Module, SaveLoadMixin): """ - def _caller(_org_func, org_module, delta_name, *args, **kwargs): - args = args[1:] # the first argument here is ``self`` - delta_module = getattr(org_module, delta_name) - ret_1 = _org_func(*args, **kwargs) - ret_2 = delta_module.forward(*args, **kwargs) - return ret_1 + ret_2 - - if strict: - if hasattr(module.forward, "__wrapped__"): - raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?") - - # record info for plug and unplug and nested wrap - if _delta_info is None: - if delta_module is None: - raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.") - - _delta_info = {"method": "insert_parallel", - "delta_module": delta_module, - "delta_name": delta_name, - "delta_belong": self, - "state": "on"} - self._register_delta_infos(parent_module=module, - _delta_info = _delta_info) - else: - delta_module = _delta_info["delta_module"] - delta_name = _delta_info["delta_name"] - - setattr(module, _delta_info['delta_name'], _delta_info["delta_module"]) - - new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.). - module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method - # for DataParallel's copy behavior. Experimental: - # may have bugs when module.forward is nestedly wrapped. - module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module)) - + self.insert_module(module, "parallel", delta_module, delta_name, strict, _delta_info) + def set_active_state_dict(self, module: nn.Module): r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part. diff --git a/opendelta/delta_models/parallel_adapter.py b/opendelta/delta_models/parallel_adapter.py index 7290614..b6530f3 100644 --- a/opendelta/delta_models/parallel_adapter.py +++ b/opendelta/delta_models/parallel_adapter.py @@ -49,12 +49,6 @@ class ParallelAdapterLayer(nn.Module): self.modulelist.add_module("up_proj", nn.Linear(self.bottleneck_dim, self.hidden_dim, device=self.device)) - # TODO: - # If we want to have a layer norm on output, we apply it later after a separate residual connection - # This means that we learn a new output layer norm, which replaces another layer norm learned in the bert layer - # if self.add_layer_norm_after: - # self.adapter_norm_after = nn.LayerNorm(self.input_size) - self.instantiated = True # initialize the weight, which is important for fast convergence and better performance. self.apply(self._init_weight) @@ -85,19 +79,25 @@ class ParallelAdapterLayer(nn.Module): self.instantiate(hidden_dim=self.hidden_dim) - self.adapter_output = self.modulelist(hiddens) * self.scaled + hiddens # TODO add hiddens? + self.adapter_output = self.modulelist(hiddens) * self.scaled return args, kwargs - def post_forward(self, *args, **kwargs): - if isinstance(args, tuple): - output = args[0] - elif isinstance(args, torch.Tensor): - output = args + def post_forward(self, output, **kwargs): + if isinstance(output, tuple): + hidden = output[0] + elif isinstance(output, torch.Tensor): + hidden = output else: raise TypeError - modified_output = self.adapter_output + output - return modified_output + modified_output = self.adapter_output + hidden + if isinstance(output, tuple): + output = (modified_output,) + output[1:] + elif isinstance(output, torch.Tensor): + output = modified_output + else: + raise TypeError + return output @@ -141,7 +141,7 @@ class ParallelAdapterModel(DeltaBase): backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified. bottleneck_dim (:obj:`int`): The dimension of the adapter's bottleneck. non_linearity (:obj:`str`): The non linearity of the adapter. - modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output. + modified_modules (:obj:`List[str]`): modules to add parallel adapter. Must be paired and have the save order in layer. For examples, ["attn", "attn", "ff.w1", "ff.w2"] add one parallel adapter from attn's input to attn's output, and another one from ff.w1's input to ff.w2's output. unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the parallel adapter parameters. common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping. @@ -182,11 +182,13 @@ class ParallelAdapterModel(DeltaBase): _, _, ref = self.find_module(module, key) if self.ith % 2 == 0: adapterlayer = self.new_module_like(ref) - self.insert_before_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter") - else: - adapterlayer = self.delta_moduels[-1] - self.insert_after_module(ref, delta_module=adapterlayer, delta_name="parallel_adapter") + self.insert_module(ref, "before", delta_module=adapterlayer, delta_name="parallel_adapter") + if self.ith % 2 == 1 or self.modified_modules[self.ith] == self.modified_modules[self.ith + 1]: + adapterlayer = self.delta_modules[-1] + self.insert_module(ref, "after", delta_module=adapterlayer, delta_name="parallel_adapter") + self.ith |= 1 self.ith += 1 + self.ith %= len(self.modified_modules) def new_module_like(self, module): module_device = get_device(module) diff --git a/opendelta/utils/data_parallel.py b/opendelta/utils/data_parallel.py index ca0c4c0..8c32297 100644 --- a/opendelta/utils/data_parallel.py +++ b/opendelta/utils/data_parallel.py @@ -4,26 +4,50 @@ from opendelta.utils.decorate import decorate from collections import OrderedDict +def sequential_caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + if hasattr(delta_module, "pre_forward"): + args, kwargs = delta_module.pre_forward(*args, **kwargs) + ret = _org_func(*args, **kwargs) + if hasattr(delta_module, "post_forward"): + ret = delta_module.post_forward(ret) + return ret + +def before_caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + if hasattr(delta_module, "pre_forward"): + args, kwargs = delta_module.pre_forward(*args, **kwargs) + ret = _org_func(*args, **kwargs) + return ret + +def after_caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + ret = _org_func(*args, **kwargs) + if hasattr(delta_module, "post_forward"): + ret = delta_module.post_forward(ret) + return ret + +def parallel_caller(_org_func, org_module, delta_name, *args, **kwargs): + args = args[1:] # the first argument here is ``self`` + delta_module = getattr(org_module, delta_name) + ret_1 = _org_func(*args, **kwargs) + ret_2 = delta_module.forward(*args, **kwargs) + return ret_1 + ret_2 + +caller_map = { + "sequential": sequential_caller, + "parallel": parallel_caller, + "before": before_caller, + "after": after_caller, +} + def new_replicate_for_data_parallel(self): r""" self is the parent module. """ # rewrite the replicate in DataParallel. - def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs): - args = args[1:] # the first argument here is ``self`` - delta_module = getattr(org_module, delta_name) - if hasattr(delta_module, "pre_forward"): - args, kwargs = delta_module.pre_forward(*args, **kwargs) - ret = _org_func(*args, **kwargs) - if hasattr(delta_module, "post_forward"): - ret = delta_module.post_forward(ret) - return ret - - def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs): - args = args[1:] # the first argument here is ``self`` - delta_module = getattr(org_module, delta_name) - ret_1 = _org_func(*args, **kwargs) - ret_2 = delta_module.forward(*args, **kwargs) - return ret_1 + ret_2 replica = self.__new__(type(self)) org_forward = replica.forward replica.__dict__ = self.__dict__.copy() @@ -33,10 +57,9 @@ def new_replicate_for_data_parallel(self): for _delta_info in self._delta_infos: if _delta_info['state'] == 'on': - if _delta_info['method'] == "insert_sequential": - new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) - elif _delta_info['method'] == "insert_parallel": - new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) + if _delta_info['method'] in caller_map.keys(): + caller = caller_map[_delta_info['method']] + new_forward = decorate(replica.forward, caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True) else: raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported") replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))