From 5e2c8e4c835893efd18395d06b88531d1bb1280e Mon Sep 17 00:00:00 2001 From: Achazwl Date: Sun, 24 Jul 2022 19:42:31 +0800 Subject: [PATCH] fix split and add batchsplit --- examples/tutorial/2_split.py | 43 +++ opendelta/delta_models/layers/activations.py | 22 +- opendelta/delta_models/lora.py | 6 +- opendelta/delta_models/split.py | 296 ++++++++++++++++--- test.py | 26 -- 5 files changed, 314 insertions(+), 79 deletions(-) create mode 100644 examples/tutorial/2_split.py delete mode 100644 test.py diff --git a/examples/tutorial/2_split.py b/examples/tutorial/2_split.py new file mode 100644 index 0000000..1a1786e --- /dev/null +++ b/examples/tutorial/2_split.py @@ -0,0 +1,43 @@ +from transformers import BertModel +model = BertModel.from_pretrained("bert-base-cased") +from opendelta import Visualization +Visualization(model).structure_graph() +from opendelta import SplitModel +delta_model = SplitModel(model) + +print("split_attach here") +delta_model.split_attach("output.dense", "adapter_A", "adapter", bottleneck_dim=12) +delta_model.split_attach("output.dense", "adapter_B", "adapter", bottleneck_dim=16, non_linearity="relu") +# delta_model.split_attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter") +# delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "lora", r=8) +delta_model.update() +delta_model.log() # This will visualize the backbone after modification and other information. + +print("batchsplit_attach here") +delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "batch_lora", r=4) +delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_B", "batch_lora", r=8) +# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_E", "batch_adapter") +# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_F", "batch_adapter") +delta_model.update() +delta_model.log() + +print("split_detach and save here") +delta_model.save_split("adapter_A", "adapter_A_split.pt") +delta_model.split_detach("adapter_A") +delta_model.save_split("lora_A", "lora_A_split.pt") +delta_model.split_detach("lora_A") +delta_model.update() +delta_model.log() # This will visualize the backbone after modification and other information. + +print("load back here") +delta_model.load_split("adapter_A", "adapter_A_split.pt") +delta_model.load_split("lora_A", "lora_A_split.pt") +delta_model.update() +delta_model.log() # This will visualize the backbone after modification and other information. + +print("run here") +import torch +x = torch.randint(0, 10, (16, 128)).cuda() +delta_model.set_batchsplit_pattern(['lora_A']*4 + ['lora_B']*12) +model = model.cuda() +y = model(x) diff --git a/opendelta/delta_models/layers/activations.py b/opendelta/delta_models/layers/activations.py index 8ce4a16..b1afdd3 100644 --- a/opendelta/delta_models/layers/activations.py +++ b/opendelta/delta_models/layers/activations.py @@ -5,6 +5,16 @@ import torch.nn as nn import torch.nn as nn from transformers.activations import get_activation +def swish(x): + return x * torch.sigmoid(x) + +def gelu_new(x): + """ + Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). + Also see https://arxiv.org/abs/1606.08415 + """ + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + class Activations(nn.Module): """ Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink @@ -17,20 +27,8 @@ class Activations(nn.Module): elif activation_type.lower() == "tanh": self.f = torch.tanh elif activation_type.lower() == "swish": - - def swish(x): - return x * torch.sigmoid(x) - self.f = swish elif activation_type.lower() == "gelu_new": - - def gelu_new(x): - """ - Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). - Also see https://arxiv.org/abs/1606.08415 - """ - return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - self.f = gelu_new elif activation_type.lower() == "gelu_orig": self.f = nn.functional.gelu diff --git a/opendelta/delta_models/lora.py b/opendelta/delta_models/lora.py index 8633aef..d3cf0e3 100644 --- a/opendelta/delta_models/lora.py +++ b/opendelta/delta_models/lora.py @@ -9,6 +9,10 @@ import torch.nn as nn from opendelta import BaseDeltaConfig import math +class Identical(nn.Module): + def forward(self, x): + return x + class LowRankLinear(nn.Module): # ------------------------------------------------------------------------------------------ # Copyright (c) Microsoft Corporation. All rights reserved. @@ -30,7 +34,7 @@ class LowRankLinear(nn.Module): if lora_dropout > 0.: self.lora_dropout = nn.Dropout(p=lora_dropout) else: - self.lora_dropout = lambda x: x + self.lora_dropout = Identical() if r > 0: self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) diff --git a/opendelta/delta_models/split.py b/opendelta/delta_models/split.py index 8f051b9..6e509e7 100644 --- a/opendelta/delta_models/split.py +++ b/opendelta/delta_models/split.py @@ -1,15 +1,14 @@ from functools import partial from hashlib import sha1 +from html.entities import name2codepoint from random import random from sqlite3 import adapters from typing import Optional, Union -from cv2 import accumulate -from opendelta.utils.signature import get_arg_names_inside_func +from opendelta.utils.signature import get_arg_names_inside_func, signature from opendelta.utils.name_based_addressing import * from opendelta.utils.cuda import get_device from opendelta.basemodel import DeltaBase -import loralib as lora import torch.nn as nn import torch import math @@ -20,7 +19,6 @@ import opendelta.utils.logging as logging import numpy as np from opendelta import global_setting logger = logging.get_logger(__name__) -from itertools import accumulate from opendelta.delta_models.adapter import AdapterLayer from opendelta.delta_models.lora import LowRankLinear @@ -31,17 +29,44 @@ class _SplitLayer(nn.Module): super().__init__() self.module_dict = nn.ModuleDict() - def attach(self, module_name: str, module: nn.Module): + def split_attach(self, module_name: str, module: nn.Module): if module_name in self.module_dict: return False self.module_dict[module_name] = module return True - def detach(self, module_name: str): + def split_detach(self, module_name: str): if module_name not in self.module_dict: - return False - self.module_dict.pop(module_name) - return True + return None + return self.module_dict.pop(module_name) + + def split_get(self, module_name: str): + if module_name not in self.module_dict: + return None + return self.module_dict[module_name] + +class _BatchSplitLayer(_SplitLayer): + r"""A layer of batch splitting module. + """ + def __init__(self): + super().__init__() + self.split_pattern = {} + + def set_batchsplit_pattern(self, list_pattern): + self.split_pattern = {} + self.list_pattern = list_pattern + self.count = [] + for i, name in enumerate(list_pattern): + if name not in self.split_pattern: + self.split_pattern[name] = [] + self.count.append(len(self.split_pattern[name])) + self.split_pattern[name].append(i) + + def get_batchsplit_pattern(self, name): + return self.split_pattern.get(name, None) + + def merge_by_pattern(self, output_dict): + return torch.stack([output_dict[name][self.count[i]] for i, name in enumerate(self.list_pattern)], dim=0) class SplitSequentialLayer(_SplitLayer): def __init__(self): @@ -60,7 +85,7 @@ class SplitSequentialLayer(_SplitLayer): split_outputs.append( module.post_forward( hiddens ) ) - print(len(split_outputs)) + print("sequential", len(split_outputs)) merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0) if isinstance(output, tuple): @@ -69,11 +94,11 @@ class SplitSequentialLayer(_SplitLayer): output = merge_output else: raise TypeError + print(hiddens.shape) + print(merge_output.shape) return output class SplitParallelLayer(_SplitLayer): - r"""A layer of splitting module. - """ def __init__(self): super().__init__() @@ -83,10 +108,58 @@ class SplitParallelLayer(_SplitLayer): split_outputs.append( module( hiddens ) ) - print(len(split_outputs)) + print("paralell", len(split_outputs)) merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0) return merge_output +class BatchSplitSequentialLayer(_BatchSplitLayer): + def __init__(self): + super().__init__() + + def post_forward(self, output): + if isinstance(output, tuple): + hiddens = output[0] + elif isinstance(output, torch.Tensor): + hiddens = output + else: + raise TypeError + + split_outputs = {} + for module_name, module in self.module_dict.items(): + pattern = self.get_batchsplit_pattern(module_name) + if pattern is not None: + split_outputs[module_name] = module.post_forward( + hiddens[pattern], + ) + merge_output = self.merge_by_pattern(split_outputs) + print(hiddens.shape) + print(merge_output.shape) + + if isinstance(output, tuple): + output = (merge_output,) + output[1:] + elif isinstance(output, torch.Tensor): + output = merge_output + else: + raise TypeError + return output + +class BatchSplitParallelLayer(_BatchSplitLayer): + def __init__(self): + super().__init__() + + def forward(self, hiddens): + split_outputs = {} + for module_name, module in self.module_dict.items(): + pattern = self.get_batchsplit_pattern(module_name) + if pattern != None: + split_outputs[module_name] = module( + hiddens[pattern], + ) + merge_output = self.merge_by_pattern(split_outputs) + print(hiddens.shape) + print(merge_output.shape) + return merge_output + class SplitConfig(BaseDeltaConfig): r""" This is the configuration class to store the configuration of a :py:class:`~SplitModel` @@ -134,7 +207,7 @@ class SplitModel(DeltaBase): """ config_class = SplitConfig - delta_type = "batch_split" + delta_type = "split" default_modified_modules = ["attn", "ff"] def __init__(self, backbone_model: nn.Module, @@ -163,6 +236,8 @@ class SplitModel(DeltaBase): self.modified_modules, ) + self.name2point = {} + self.batch_layer = {} self.modified_points = {} def add_all_delta_to_backbone(self, @@ -187,13 +262,62 @@ class SplitModel(DeltaBase): # create a new key list to avoid recursion. self.backbone_key_list = [key for key, _ in backbone.named_modules()] return backbone - - def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None): + + def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None): + r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph. + First try to use the dummy_inputs of the pretrained model. If the model has no dummy_inputs, will try to create + integer tensor as the pseudo_input, if ``decoder_input_ids`` is in the model's forward function, additional create it. + + Args: + module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model. + + """ + device = get_device(module) + logger.warning("No dummy_inputs attributes, create a common input_ids for input.") + if len(self.batch_layer) > 0: + pseudo_input = torch.tensor([[0,0,0,0]]*len(self.batch_layer)).to(device) + self.set_batchsplit_pattern(list(self.batch_layer.keys())) + else: + pseudo_input = torch.tensor([[0,0,0,0]]).to(device) + print(pseudo_input) + if "decoder_input_ids" in signature(module.forward).args: + module(pseudo_input, decoder_input_ids = pseudo_input) + else: + module(pseudo_input) + + def update(self): + self._pseudo_data_to_instantiate(self.backbone_model) + self.mark_as_delta() + + def set_batchsplit_pattern(self, + pattern: List, + ): + r"""Set the batch split pattern. + + Args: + pattern (:obj:`List`): The batch split pattern. + + """ + for module_name, layer_list in self.batch_layer.items(): + for batch_layer in layer_list: + batch_layer.set_batchsplit_pattern(pattern) + + def split_attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, **kwargs): + if module_name in self.modified_points: + raise ValueError(f"{module_name} already in delta model") + if isinstance(modified_point, str): modified_point = [modified_point] - if module_type not in ["adapter", "lora"]: - raise ValueError("module_type must be either adapter or lora") + self.name2point[module_name] = (module_type, modified_point) + + advailable = ["adapter", "lora"] + advailable += [f"batch_{a}" for a in advailable] + if module_type not in advailable: + raise ValueError(f"module_type must be in {' '.join(advailable)}.") + + if module_type.startswith("batch_"): + self.batch_layer[module_name] = [] for key in self.backbone_key_list: if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered @@ -204,32 +328,42 @@ class SplitModel(DeltaBase): if module_type == "adapter": splitlayer = SplitSequentialLayer() self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential") - if module_type == "lora": + elif module_type == "lora": splitlayer = SplitParallelLayer() self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel") + elif module_type == "batch_adapter": + splitlayer = BatchSplitSequentialLayer() + self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential") + elif module_type == "batch_lora": + splitlayer = BatchSplitParallelLayer() + self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel") self.modified_points[key] = splitlayer splitlayer = self.modified_points[key] if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \ - (module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)): + (module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \ + (module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \ + (module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)): raise ValueError("one modified_point can have at most one module_type") - if module is None: - if module_type == "adapter": - module = self.new_adapter_like(ref) - if module_type == "lora": - module = self.new_lora_like(ref) + if module_type.startswith("batch_"): + self.batch_layer[module_name].append(splitlayer) + delta_type = module_type[6:] + else: + delta_type = module_type - if not splitlayer.attach(module_name, module): + if delta_type == "adapter": + module = self.new_adapter_like(ref, **kwargs) + elif delta_type == "lora": + module = self.new_lora_like(ref, **kwargs) + + if not splitlayer.split_attach(module_name, module): raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key)) - - def update(self): - self._pseudo_data_to_instantiate(self.backbone_model) - self.mark_as_delta() - def detach(self, modified_point: str, module_name: str): - if isinstance(modified_point, str): - modified_point = [modified_point] + def split_detach(self, module_name: str): + if module_name not in self.name2point: + raise ValueError(f"{module_name} not in delta model") + module_type, modified_point = self.name2point.pop(module_name) for key in self.backbone_key_list: if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered @@ -241,23 +375,105 @@ class SplitModel(DeltaBase): splitlayer = self.modified_points[key] - if not splitlayer.detach(module_name): + module = splitlayer.split_detach(module_name) + if module is None: + raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key)) + + if module_type.startswith("batch_"): + self.batch_layer.pop(module_name) + + def save_split(self, module_name: str, save_name: str): + if module_name not in self.name2point: + raise ValueError(f"{module_name} not in delta model") + module_type, modified_point = self.name2point[module_name] + print("Save", module_name, modified_point) + + module_dict = nn.ModuleDict() + + for key in self.backbone_key_list: + print("find", key, modified_point, self.find_key(key, modified_point)) + if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered + logger.debug("find key: {}".format(key)) + _, _, ref = self.find_module(self.backbone_model, key) + + if key not in self.modified_points: + raise ValueError("no module has been added to {}".format(key)) + + splitlayer = self.modified_points[key] + + module = splitlayer.split_get(module_name) + if module is None: raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key)) - def new_adapter_like(self, module): - adapterlayer = AdapterLayer() + module_dict[f"{module_type}:{key.replace('.', ':')}"] = module + + print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]]) + torch.save(module_dict, save_name) + + def load_split(self, module_name: str, load_name: str): + if module_name in self.modified_points: + raise ValueError(f"{module_name} already in delta model") + + module_dict = torch.load(load_name) + print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]]) + + keys = [key.split(':',maxsplit=1)[1].replace(':', '.') for key in module_dict.keys()] + module_types = [key.split(':',maxsplit=1)[0] for key in module_dict.keys()] + module_type = module_types[0] + print(keys) + print(module_type) + + self.name2point[module_name] = (module_type, keys) + + if module_types[0].startswith("batch_"): + self.batch_layer[module_name] = [] + + for key in self.backbone_key_list: + if key in keys: + logger.debug("find key: {}".format(key)) + _, _, ref = self.find_module(self.backbone_model, key) + module = module_dict[list(module_dict.keys())[keys.index(key)]] + + if key not in self.modified_points: + if module_type == "adapter": + splitlayer = SplitSequentialLayer() + self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential") + elif module_type == "lora": + splitlayer = SplitParallelLayer() + self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel") + elif module_type == "batch_adapter": + splitlayer = BatchSplitSequentialLayer() + self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential") + elif module_type == "batch_lora": + splitlayer = BatchSplitParallelLayer() + self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel") + self.modified_points[key] = splitlayer + + splitlayer = self.modified_points[key] + if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \ + (module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \ + (module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \ + (module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)): + raise ValueError("one modified_point can have at most one module_type") + + if module_type.startswith("batch_"): + self.batch_layer[module_name].append(splitlayer) + + if not splitlayer.split_attach(module_name, module): + raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key)) + + def new_adapter_like(self, module, **kwargs): + adapterlayer = AdapterLayer(**kwargs) self.delta_modules.append(adapterlayer) return adapterlayer - def new_lora_like(self, child_module): + def new_lora_like(self, child_module, **kwargs): if isinstance(child_module, nn.Linear): in_features, out_features = child_module.in_features, child_module.out_features new_module = LowRankLinear(in_features = in_features, out_features = out_features, weight = child_module.weight, - r=self.lora_r, - lora_alpha=self.lora_alpha, - lora_dropout=self.lora_dropout) + **kwargs,) self.delta_modules.append(new_module) else: raise NotImplementedError diff --git a/test.py b/test.py deleted file mode 100644 index 16e68c3..0000000 --- a/test.py +++ /dev/null @@ -1,26 +0,0 @@ -from transformers import BertModel -model = BertModel.from_pretrained("bert-base-cased") -from opendelta import Visualization -Visualization(model).structure_graph() -from opendelta import SplitModel -delta_model = SplitModel(model) -print("attach here") -delta_model.attach("output.dense", "adapter_A", "adapter") -delta_model.attach("output.dense", "adapter_B", "adapter") -delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter") -delta_model.update() -# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora") -delta_model.log() # This will visualize the backbone after modification and other information. -import torch -x = torch.randint(0, 10, (16, 128)).cuda() -import time -model = model.cuda() -st_time = time.time() -print("run here") -y = model(x) -print("detach here") -delta_model.detach("3.output.dense", "adapter_A") -delta_model.log() -print("run here") -y = model(x) -print(time.time() - st_time) \ No newline at end of file