fix split and add batchsplit

This commit is contained in:
Achazwl 2022-07-24 19:42:31 +08:00
parent f8db5be89b
commit 5e2c8e4c83
5 changed files with 314 additions and 79 deletions

View File

@ -0,0 +1,43 @@
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model)
print("split_attach here")
delta_model.split_attach("output.dense", "adapter_A", "adapter", bottleneck_dim=12)
delta_model.split_attach("output.dense", "adapter_B", "adapter", bottleneck_dim=16, non_linearity="relu")
# delta_model.split_attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "lora", r=8)
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("batchsplit_attach here")
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "batch_lora", r=4)
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_B", "batch_lora", r=8)
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_E", "batch_adapter")
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_F", "batch_adapter")
delta_model.update()
delta_model.log()
print("split_detach and save here")
delta_model.save_split("adapter_A", "adapter_A_split.pt")
delta_model.split_detach("adapter_A")
delta_model.save_split("lora_A", "lora_A_split.pt")
delta_model.split_detach("lora_A")
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("load back here")
delta_model.load_split("adapter_A", "adapter_A_split.pt")
delta_model.load_split("lora_A", "lora_A_split.pt")
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("run here")
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
delta_model.set_batchsplit_pattern(['lora_A']*4 + ['lora_B']*12)
model = model.cuda()
y = model(x)

View File

@ -5,6 +5,16 @@ import torch.nn as nn
import torch.nn as nn import torch.nn as nn
from transformers.activations import get_activation from transformers.activations import get_activation
def swish(x):
return x * torch.sigmoid(x)
def gelu_new(x):
"""
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class Activations(nn.Module): class Activations(nn.Module):
""" """
Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink
@ -17,20 +27,8 @@ class Activations(nn.Module):
elif activation_type.lower() == "tanh": elif activation_type.lower() == "tanh":
self.f = torch.tanh self.f = torch.tanh
elif activation_type.lower() == "swish": elif activation_type.lower() == "swish":
def swish(x):
return x * torch.sigmoid(x)
self.f = swish self.f = swish
elif activation_type.lower() == "gelu_new": elif activation_type.lower() == "gelu_new":
def gelu_new(x):
"""
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
self.f = gelu_new self.f = gelu_new
elif activation_type.lower() == "gelu_orig": elif activation_type.lower() == "gelu_orig":
self.f = nn.functional.gelu self.f = nn.functional.gelu

View File

@ -9,6 +9,10 @@ import torch.nn as nn
from opendelta import BaseDeltaConfig from opendelta import BaseDeltaConfig
import math import math
class Identical(nn.Module):
def forward(self, x):
return x
class LowRankLinear(nn.Module): class LowRankLinear(nn.Module):
# ------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved. # Copyright (c) Microsoft Corporation. All rights reserved.
@ -30,7 +34,7 @@ class LowRankLinear(nn.Module):
if lora_dropout > 0.: if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout) self.lora_dropout = nn.Dropout(p=lora_dropout)
else: else:
self.lora_dropout = lambda x: x self.lora_dropout = Identical()
if r > 0: if r > 0:
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))

View File

@ -1,15 +1,14 @@
from functools import partial from functools import partial
from hashlib import sha1 from hashlib import sha1
from html.entities import name2codepoint
from random import random from random import random
from sqlite3 import adapters from sqlite3 import adapters
from typing import Optional, Union from typing import Optional, Union
from cv2 import accumulate from opendelta.utils.signature import get_arg_names_inside_func, signature
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import * from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase from opendelta.basemodel import DeltaBase
import loralib as lora
import torch.nn as nn import torch.nn as nn
import torch import torch
import math import math
@ -20,7 +19,6 @@ import opendelta.utils.logging as logging
import numpy as np import numpy as np
from opendelta import global_setting from opendelta import global_setting
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
from itertools import accumulate
from opendelta.delta_models.adapter import AdapterLayer from opendelta.delta_models.adapter import AdapterLayer
from opendelta.delta_models.lora import LowRankLinear from opendelta.delta_models.lora import LowRankLinear
@ -31,17 +29,44 @@ class _SplitLayer(nn.Module):
super().__init__() super().__init__()
self.module_dict = nn.ModuleDict() self.module_dict = nn.ModuleDict()
def attach(self, module_name: str, module: nn.Module): def split_attach(self, module_name: str, module: nn.Module):
if module_name in self.module_dict: if module_name in self.module_dict:
return False return False
self.module_dict[module_name] = module self.module_dict[module_name] = module
return True return True
def detach(self, module_name: str): def split_detach(self, module_name: str):
if module_name not in self.module_dict: if module_name not in self.module_dict:
return False return None
self.module_dict.pop(module_name) return self.module_dict.pop(module_name)
return True
def split_get(self, module_name: str):
if module_name not in self.module_dict:
return None
return self.module_dict[module_name]
class _BatchSplitLayer(_SplitLayer):
r"""A layer of batch splitting module.
"""
def __init__(self):
super().__init__()
self.split_pattern = {}
def set_batchsplit_pattern(self, list_pattern):
self.split_pattern = {}
self.list_pattern = list_pattern
self.count = []
for i, name in enumerate(list_pattern):
if name not in self.split_pattern:
self.split_pattern[name] = []
self.count.append(len(self.split_pattern[name]))
self.split_pattern[name].append(i)
def get_batchsplit_pattern(self, name):
return self.split_pattern.get(name, None)
def merge_by_pattern(self, output_dict):
return torch.stack([output_dict[name][self.count[i]] for i, name in enumerate(self.list_pattern)], dim=0)
class SplitSequentialLayer(_SplitLayer): class SplitSequentialLayer(_SplitLayer):
def __init__(self): def __init__(self):
@ -60,7 +85,7 @@ class SplitSequentialLayer(_SplitLayer):
split_outputs.append( module.post_forward( split_outputs.append( module.post_forward(
hiddens hiddens
) ) ) )
print(len(split_outputs)) print("sequential", len(split_outputs))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0) merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
if isinstance(output, tuple): if isinstance(output, tuple):
@ -69,11 +94,11 @@ class SplitSequentialLayer(_SplitLayer):
output = merge_output output = merge_output
else: else:
raise TypeError raise TypeError
print(hiddens.shape)
print(merge_output.shape)
return output return output
class SplitParallelLayer(_SplitLayer): class SplitParallelLayer(_SplitLayer):
r"""A layer of splitting module.
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -83,10 +108,58 @@ class SplitParallelLayer(_SplitLayer):
split_outputs.append( module( split_outputs.append( module(
hiddens hiddens
) ) ) )
print(len(split_outputs)) print("paralell", len(split_outputs))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0) merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
return merge_output return merge_output
class BatchSplitSequentialLayer(_BatchSplitLayer):
def __init__(self):
super().__init__()
def post_forward(self, output):
if isinstance(output, tuple):
hiddens = output[0]
elif isinstance(output, torch.Tensor):
hiddens = output
else:
raise TypeError
split_outputs = {}
for module_name, module in self.module_dict.items():
pattern = self.get_batchsplit_pattern(module_name)
if pattern is not None:
split_outputs[module_name] = module.post_forward(
hiddens[pattern],
)
merge_output = self.merge_by_pattern(split_outputs)
print(hiddens.shape)
print(merge_output.shape)
if isinstance(output, tuple):
output = (merge_output,) + output[1:]
elif isinstance(output, torch.Tensor):
output = merge_output
else:
raise TypeError
return output
class BatchSplitParallelLayer(_BatchSplitLayer):
def __init__(self):
super().__init__()
def forward(self, hiddens):
split_outputs = {}
for module_name, module in self.module_dict.items():
pattern = self.get_batchsplit_pattern(module_name)
if pattern != None:
split_outputs[module_name] = module(
hiddens[pattern],
)
merge_output = self.merge_by_pattern(split_outputs)
print(hiddens.shape)
print(merge_output.shape)
return merge_output
class SplitConfig(BaseDeltaConfig): class SplitConfig(BaseDeltaConfig):
r""" r"""
This is the configuration class to store the configuration of a :py:class:`~SplitModel` This is the configuration class to store the configuration of a :py:class:`~SplitModel`
@ -134,7 +207,7 @@ class SplitModel(DeltaBase):
""" """
config_class = SplitConfig config_class = SplitConfig
delta_type = "batch_split" delta_type = "split"
default_modified_modules = ["attn", "ff"] default_modified_modules = ["attn", "ff"]
def __init__(self, def __init__(self,
backbone_model: nn.Module, backbone_model: nn.Module,
@ -163,6 +236,8 @@ class SplitModel(DeltaBase):
self.modified_modules, self.modified_modules,
) )
self.name2point = {}
self.batch_layer = {}
self.modified_points = {} self.modified_points = {}
def add_all_delta_to_backbone(self, def add_all_delta_to_backbone(self,
@ -187,13 +262,62 @@ class SplitModel(DeltaBase):
# create a new key list to avoid recursion. # create a new key list to avoid recursion.
self.backbone_key_list = [key for key, _ in backbone.named_modules()] self.backbone_key_list = [key for key, _ in backbone.named_modules()]
return backbone return backbone
def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None): def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph.
First try to use the dummy_inputs of the pretrained model. If the model has no dummy_inputs, will try to create
integer tensor as the pseudo_input, if ``decoder_input_ids`` is in the model's forward function, additional create it.
Args:
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
"""
device = get_device(module)
logger.warning("No dummy_inputs attributes, create a common input_ids for input.")
if len(self.batch_layer) > 0:
pseudo_input = torch.tensor([[0,0,0,0]]*len(self.batch_layer)).to(device)
self.set_batchsplit_pattern(list(self.batch_layer.keys()))
else:
pseudo_input = torch.tensor([[0,0,0,0]]).to(device)
print(pseudo_input)
if "decoder_input_ids" in signature(module.forward).args:
module(pseudo_input, decoder_input_ids = pseudo_input)
else:
module(pseudo_input)
def update(self):
self._pseudo_data_to_instantiate(self.backbone_model)
self.mark_as_delta()
def set_batchsplit_pattern(self,
pattern: List,
):
r"""Set the batch split pattern.
Args:
pattern (:obj:`List`): The batch split pattern.
"""
for module_name, layer_list in self.batch_layer.items():
for batch_layer in layer_list:
batch_layer.set_batchsplit_pattern(pattern)
def split_attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, **kwargs):
if module_name in self.modified_points:
raise ValueError(f"{module_name} already in delta model")
if isinstance(modified_point, str): if isinstance(modified_point, str):
modified_point = [modified_point] modified_point = [modified_point]
if module_type not in ["adapter", "lora"]: self.name2point[module_name] = (module_type, modified_point)
raise ValueError("module_type must be either adapter or lora")
advailable = ["adapter", "lora"]
advailable += [f"batch_{a}" for a in advailable]
if module_type not in advailable:
raise ValueError(f"module_type must be in {' '.join(advailable)}.")
if module_type.startswith("batch_"):
self.batch_layer[module_name] = []
for key in self.backbone_key_list: for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
@ -204,32 +328,42 @@ class SplitModel(DeltaBase):
if module_type == "adapter": if module_type == "adapter":
splitlayer = SplitSequentialLayer() splitlayer = SplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential") self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
if module_type == "lora": elif module_type == "lora":
splitlayer = SplitParallelLayer() splitlayer = SplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel") self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
elif module_type == "batch_adapter":
splitlayer = BatchSplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
elif module_type == "batch_lora":
splitlayer = BatchSplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
self.modified_points[key] = splitlayer self.modified_points[key] = splitlayer
splitlayer = self.modified_points[key] splitlayer = self.modified_points[key]
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \ if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)): (module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
raise ValueError("one modified_point can have at most one module_type") raise ValueError("one modified_point can have at most one module_type")
if module is None: if module_type.startswith("batch_"):
if module_type == "adapter": self.batch_layer[module_name].append(splitlayer)
module = self.new_adapter_like(ref) delta_type = module_type[6:]
if module_type == "lora": else:
module = self.new_lora_like(ref) delta_type = module_type
if not splitlayer.attach(module_name, module): if delta_type == "adapter":
module = self.new_adapter_like(ref, **kwargs)
elif delta_type == "lora":
module = self.new_lora_like(ref, **kwargs)
if not splitlayer.split_attach(module_name, module):
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key)) raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
def update(self):
self._pseudo_data_to_instantiate(self.backbone_model)
self.mark_as_delta()
def detach(self, modified_point: str, module_name: str): def split_detach(self, module_name: str):
if isinstance(modified_point, str): if module_name not in self.name2point:
modified_point = [modified_point] raise ValueError(f"{module_name} not in delta model")
module_type, modified_point = self.name2point.pop(module_name)
for key in self.backbone_key_list: for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
@ -241,23 +375,105 @@ class SplitModel(DeltaBase):
splitlayer = self.modified_points[key] splitlayer = self.modified_points[key]
if not splitlayer.detach(module_name): module = splitlayer.split_detach(module_name)
if module is None:
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
if module_type.startswith("batch_"):
self.batch_layer.pop(module_name)
def save_split(self, module_name: str, save_name: str):
if module_name not in self.name2point:
raise ValueError(f"{module_name} not in delta model")
module_type, modified_point = self.name2point[module_name]
print("Save", module_name, modified_point)
module_dict = nn.ModuleDict()
for key in self.backbone_key_list:
print("find", key, modified_point, self.find_key(key, modified_point))
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
if key not in self.modified_points:
raise ValueError("no module has been added to {}".format(key))
splitlayer = self.modified_points[key]
module = splitlayer.split_get(module_name)
if module is None:
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key)) raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
def new_adapter_like(self, module): module_dict[f"{module_type}:{key.replace('.', ':')}"] = module
adapterlayer = AdapterLayer()
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
torch.save(module_dict, save_name)
def load_split(self, module_name: str, load_name: str):
if module_name in self.modified_points:
raise ValueError(f"{module_name} already in delta model")
module_dict = torch.load(load_name)
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
keys = [key.split(':',maxsplit=1)[1].replace(':', '.') for key in module_dict.keys()]
module_types = [key.split(':',maxsplit=1)[0] for key in module_dict.keys()]
module_type = module_types[0]
print(keys)
print(module_type)
self.name2point[module_name] = (module_type, keys)
if module_types[0].startswith("batch_"):
self.batch_layer[module_name] = []
for key in self.backbone_key_list:
if key in keys:
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
module = module_dict[list(module_dict.keys())[keys.index(key)]]
if key not in self.modified_points:
if module_type == "adapter":
splitlayer = SplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
elif module_type == "lora":
splitlayer = SplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
elif module_type == "batch_adapter":
splitlayer = BatchSplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
elif module_type == "batch_lora":
splitlayer = BatchSplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
self.modified_points[key] = splitlayer
splitlayer = self.modified_points[key]
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
raise ValueError("one modified_point can have at most one module_type")
if module_type.startswith("batch_"):
self.batch_layer[module_name].append(splitlayer)
if not splitlayer.split_attach(module_name, module):
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
def new_adapter_like(self, module, **kwargs):
adapterlayer = AdapterLayer(**kwargs)
self.delta_modules.append(adapterlayer) self.delta_modules.append(adapterlayer)
return adapterlayer return adapterlayer
def new_lora_like(self, child_module): def new_lora_like(self, child_module, **kwargs):
if isinstance(child_module, nn.Linear): if isinstance(child_module, nn.Linear):
in_features, out_features = child_module.in_features, child_module.out_features in_features, out_features = child_module.in_features, child_module.out_features
new_module = LowRankLinear(in_features = in_features, new_module = LowRankLinear(in_features = in_features,
out_features = out_features, out_features = out_features,
weight = child_module.weight, weight = child_module.weight,
r=self.lora_r, **kwargs,)
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
self.delta_modules.append(new_module) self.delta_modules.append(new_module)
else: else:
raise NotImplementedError raise NotImplementedError

26
test.py
View File

@ -1,26 +0,0 @@
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model)
print("attach here")
delta_model.attach("output.dense", "adapter_A", "adapter")
delta_model.attach("output.dense", "adapter_B", "adapter")
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
delta_model.update()
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
delta_model.log() # This will visualize the backbone after modification and other information.
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
import time
model = model.cuda()
st_time = time.time()
print("run here")
y = model(x)
print("detach here")
delta_model.detach("3.output.dense", "adapter_A")
delta_model.log()
print("run here")
y = model(x)
print(time.time() - st_time)