fix split and add batchsplit

Achazwl 2022-07-24 19:42:31 +08:00
parent f8db5be89b
commit 5e2c8e4c83
5 changed files with 314 additions and 79 deletions

View File

@ -0,0 +1,43 @@
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model)
print("split_attach here")
delta_model.split_attach("output.dense", "adapter_A", "adapter", bottleneck_dim=12)
delta_model.split_attach("output.dense", "adapter_B", "adapter", bottleneck_dim=16, non_linearity="relu")
# delta_model.split_attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "lora", r=8)
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("batchsplit_attach here")
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "batch_lora", r=4)
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_B", "batch_lora", r=8)
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_E", "batch_adapter")
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_F", "batch_adapter")
delta_model.update()
delta_model.log()
print("split_detach and save here")
delta_model.save_split("adapter_A", "adapter_A_split.pt")
delta_model.split_detach("adapter_A")
delta_model.save_split("lora_A", "lora_A_split.pt")
delta_model.split_detach("lora_A")
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("load back here")
delta_model.load_split("adapter_A", "adapter_A_split.pt")
delta_model.load_split("lora_A", "lora_A_split.pt")
delta_model.update()
delta_model.log() # This will visualize the backbone after modification and other information.
print("run here")
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
delta_model.set_batchsplit_pattern(['lora_A']*4 + ['lora_B']*12)
model = model.cuda()
y = model(x)
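A minimal standalone sketch (an assumption based on the _BatchSplitLayer code later in this commit, not part of the example file) of how the pattern above partitions the batch of 16: one module name per example, so rows 0-3 run through lora_A and rows 4-15 through lora_B.

pattern = ['lora_A'] * 4 + ['lora_B'] * 12
sub_batches = {}
for i, name in enumerate(pattern):
    sub_batches.setdefault(name, []).append(i)   # record which rows each module handles
print(sub_batches['lora_A'])   # [0, 1, 2, 3]
print(sub_batches['lora_B'])   # [4, 5, ..., 15]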

View File

@ -5,6 +5,16 @@ import torch.nn as nn
import torch.nn as nn
from transformers.activations import get_activation
def swish(x):
return x * torch.sigmoid(x)
def gelu_new(x):
"""
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class Activations(nn.Module):
"""
Implementation of various activation functions. Copied from the open-source project AdapterHub. #TODO: add link
@ -17,20 +27,8 @@ class Activations(nn.Module):
elif activation_type.lower() == "tanh":
self.f = torch.tanh
elif activation_type.lower() == "swish":
def swish(x):
return x * torch.sigmoid(x)
self.f = swish
elif activation_type.lower() == "gelu_new":
def gelu_new(x):
"""
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
Also see https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
self.f = gelu_new
elif activation_type.lower() == "gelu_orig":
self.f = nn.functional.gelu
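As a quick sanity check (a sketch, not part of this commit), the gelu_new tanh formula moved to module level above is the standard GELU approximation and tracks PyTorch's exact GELU closely:

import math
import torch

x = torch.linspace(-3, 3, steps=13)
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
print(torch.max(torch.abs(approx - torch.nn.functional.gelu(x))))  # well below 1e-3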

View File

@ -9,6 +9,10 @@ import torch.nn as nn
from opendelta import BaseDeltaConfig
import math
class Identical(nn.Module):
def forward(self, x):
return x
class LowRankLinear(nn.Module):
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
@ -30,7 +34,7 @@ class LowRankLinear(nn.Module):
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
self.lora_dropout = Identical()
if r > 0:
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
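A plausible motivation for this change (an assumption; the commit does not state it): a lambda stored on a module cannot be pickled, so torch.save on a module tree containing it fails, whereas an nn.Module such as Identical (or the built-in nn.Identity) serializes cleanly, which the save_split feature added below depends on. A minimal sketch:

import io
import pickle
import torch
import torch.nn as nn

torch.save(nn.Identity(), io.BytesIO())   # an nn.Module identity serializes fine
try:
    pickle.dumps(lambda x: x)             # a lambda attribute does not
except Exception as err:
    print(type(err).__name__)             # PicklingError (or AttributeError)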

View File

@ -1,15 +1,14 @@
from functools import partial
from hashlib import sha1
from html.entities import name2codepoint
from random import random
from sqlite3 import adapters
from typing import Optional, Union
from cv2 import accumulate
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.signature import get_arg_names_inside_func, signature
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
import loralib as lora
import torch.nn as nn
import torch
import math
@ -20,7 +19,6 @@ import opendelta.utils.logging as logging
import numpy as np
from opendelta import global_setting
logger = logging.get_logger(__name__)
from itertools import accumulate
from opendelta.delta_models.adapter import AdapterLayer
from opendelta.delta_models.lora import LowRankLinear
@ -31,17 +29,44 @@ class _SplitLayer(nn.Module):
super().__init__()
self.module_dict = nn.ModuleDict()
def attach(self, module_name: str, module: nn.Module):
def split_attach(self, module_name: str, module: nn.Module):
if module_name in self.module_dict:
return False
self.module_dict[module_name] = module
return True
def detach(self, module_name: str):
def split_detach(self, module_name: str):
if module_name not in self.module_dict:
return False
self.module_dict.pop(module_name)
return True
return None
return self.module_dict.pop(module_name)
def split_get(self, module_name: str):
if module_name not in self.module_dict:
return None
return self.module_dict[module_name]
class _BatchSplitLayer(_SplitLayer):
r"""A layer of batch splitting module.
"""
def __init__(self):
super().__init__()
self.split_pattern = {}
def set_batchsplit_pattern(self, list_pattern):
self.split_pattern = {}
self.list_pattern = list_pattern
self.count = []
for i, name in enumerate(list_pattern):
if name not in self.split_pattern:
self.split_pattern[name] = []
self.count.append(len(self.split_pattern[name]))
self.split_pattern[name].append(i)
def get_batchsplit_pattern(self, name):
return self.split_pattern.get(name, None)
def merge_by_pattern(self, output_dict):
return torch.stack([output_dict[name][self.count[i]] for i, name in enumerate(self.list_pattern)], dim=0)
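To make the bookkeeping above concrete, a standalone retrace (a sketch, not part of the commit) of how split_pattern, count and merge_by_pattern restore the original batch order: split_pattern maps each name to its row indices, count[i] is example i's position inside its own sub-batch, and the stack re-interleaves the per-module outputs.

import torch

list_pattern = ['a', 'b', 'a', 'b']
split_pattern, count = {}, []
for i, name in enumerate(list_pattern):
    split_pattern.setdefault(name, [])
    count.append(len(split_pattern[name]))
    split_pattern[name].append(i)
# split_pattern == {'a': [0, 2], 'b': [1, 3]}, count == [0, 0, 1, 1]
output_dict = {'a': torch.zeros(2, 3), 'b': torch.ones(2, 3)}   # toy sub-batch outputs
merged = torch.stack([output_dict[n][count[i]] for i, n in enumerate(list_pattern)], dim=0)
print(merged[:, 0])   # tensor([0., 1., 0., 1.]) -- rows back in batch order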
class SplitSequentialLayer(_SplitLayer):
def __init__(self):
@ -60,7 +85,7 @@ class SplitSequentialLayer(_SplitLayer):
split_outputs.append( module.post_forward(
hiddens
) )
print(len(split_outputs))
print("sequential", len(split_outputs))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
if isinstance(output, tuple):
@ -69,11 +94,11 @@ class SplitSequentialLayer(_SplitLayer):
output = merge_output
else:
raise TypeError
print(hiddens.shape)
print(merge_output.shape)
return output
class SplitParallelLayer(_SplitLayer):
r"""A layer of splitting module.
"""
def __init__(self):
super().__init__()
@ -83,10 +108,58 @@ class SplitParallelLayer(_SplitLayer):
split_outputs.append( module(
hiddens
) )
print(len(split_outputs))
print("paralell", len(split_outputs))
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
return merge_output
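For reference, a standalone sketch of the merge rule both plain split layers use: every attached sub-module sees the same hidden states and the per-module outputs are summed into a single tensor (the Linear stand-ins below are assumptions, not the commit's adapter or LoRA modules).

import torch
import torch.nn as nn

module_dict = nn.ModuleDict({
    "delta_A": nn.Linear(8, 8, bias=False),   # stand-in for an attached delta
    "delta_B": nn.Linear(8, 8, bias=False),
})
hiddens = torch.randn(2, 8)
split_outputs = [m(hiddens) for m in module_dict.values()]
merged = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
print(merged.shape)   # torch.Size([2, 8]) -- contributions summed, shape preserved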
class BatchSplitSequentialLayer(_BatchSplitLayer):
def __init__(self):
super().__init__()
def post_forward(self, output):
if isinstance(output, tuple):
hiddens = output[0]
elif isinstance(output, torch.Tensor):
hiddens = output
else:
raise TypeError
split_outputs = {}
for module_name, module in self.module_dict.items():
pattern = self.get_batchsplit_pattern(module_name)
if pattern is not None:
split_outputs[module_name] = module.post_forward(
hiddens[pattern],
)
merge_output = self.merge_by_pattern(split_outputs)
print(hiddens.shape)
print(merge_output.shape)
if isinstance(output, tuple):
output = (merge_output,) + output[1:]
elif isinstance(output, torch.Tensor):
output = merge_output
else:
raise TypeError
return output
class BatchSplitParallelLayer(_BatchSplitLayer):
def __init__(self):
super().__init__()
def forward(self, hiddens):
split_outputs = {}
for module_name, module in self.module_dict.items():
pattern = self.get_batchsplit_pattern(module_name)
if pattern is not None:
split_outputs[module_name] = module(
hiddens[pattern],
)
merge_output = self.merge_by_pattern(split_outputs)
print(hiddens.shape)
print(merge_output.shape)
return merge_output
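The hiddens[pattern] expression used by both batch-split layers is plain list indexing along the batch dimension; a tiny sketch:

import torch

hiddens = torch.arange(12.).reshape(4, 3)   # a toy batch of 4 hidden states
pattern = [0, 2]                            # rows routed to one batch-split module
print(hiddens[pattern].shape)               # torch.Size([2, 3]) -- just that sub-batch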
class SplitConfig(BaseDeltaConfig):
r"""
This is the configuration class to store the configuration of a :py:class:`~SplitModel`
@ -134,7 +207,7 @@ class SplitModel(DeltaBase):
"""
config_class = SplitConfig
delta_type = "batch_split"
delta_type = "split"
default_modified_modules = ["attn", "ff"]
def __init__(self,
backbone_model: nn.Module,
@ -163,6 +236,8 @@ class SplitModel(DeltaBase):
self.modified_modules,
)
self.name2point = {}
self.batch_layer = {}
self.modified_points = {}
def add_all_delta_to_backbone(self,
@ -188,12 +263,61 @@ class SplitModel(DeltaBase):
self.backbone_key_list = [key for key, _ in backbone.named_modules()]
return backbone
def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None):
def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph.
First try to use the dummy_inputs of the pretrained model. If the model has no dummy_inputs, will try to create
integer tensor as the pseudo_input, if ``decoder_input_ids`` is in the model's forward function, additional create it.
Args:
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
"""
device = get_device(module)
logger.warning("No dummy_inputs attributes, create a common input_ids for input.")
if len(self.batch_layer) > 0:
pseudo_input = torch.tensor([[0,0,0,0]]*len(self.batch_layer)).to(device)
self.set_batchsplit_pattern(list(self.batch_layer.keys()))
else:
pseudo_input = torch.tensor([[0,0,0,0]]).to(device)
print(pseudo_input)
if "decoder_input_ids" in signature(module.forward).args:
module(pseudo_input, decoder_input_ids = pseudo_input)
else:
module(pseudo_input)
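For reference, a sketch of the pseudo batch built above when two batch-split modules are registered (the names are assumptions taken from the example script): one 4-token dummy row per batch-split module, so every branch gets instantiated during update().

import torch

batch_layer_names = ['lora_A', 'lora_B']   # assumed registered batch-split names
pseudo_input = torch.tensor([[0, 0, 0, 0]] * len(batch_layer_names))
print(pseudo_input.shape)                  # torch.Size([2, 4])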
def update(self):
self._pseudo_data_to_instantiate(self.backbone_model)
self.mark_as_delta()
def set_batchsplit_pattern(self,
pattern: List,
):
r"""Set the batch split pattern.
Args:
pattern (:obj:`List`): The batch split pattern, i.e. one attached batch-split module name per example in the batch.
"""
for module_name, layer_list in self.batch_layer.items():
for batch_layer in layer_list:
batch_layer.set_batchsplit_pattern(pattern)
def split_attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, **kwargs):
if module_name in self.modified_points:
raise ValueError(f"{module_name} already in delta model")
if isinstance(modified_point, str):
modified_point = [modified_point]
if module_type not in ["adapter", "lora"]:
raise ValueError("module_type must be either adapter or lora")
self.name2point[module_name] = (module_type, modified_point)
advailable = ["adapter", "lora"]
advailable += [f"batch_{a}" for a in advailable]
if module_type not in advailable:
raise ValueError(f"module_type must be in {' '.join(advailable)}.")
if module_type.startswith("batch_"):
self.batch_layer[module_name] = []
for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO: may have bugs when the common structure has a virtual node and it is referred to
@ -204,32 +328,42 @@ class SplitModel(DeltaBase):
if module_type == "adapter":
splitlayer = SplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
if module_type == "lora":
elif module_type == "lora":
splitlayer = SplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
elif module_type == "batch_adapter":
splitlayer = BatchSplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
elif module_type == "batch_lora":
splitlayer = BatchSplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
self.modified_points[key] = splitlayer
splitlayer = self.modified_points[key]
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)):
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
raise ValueError("one modified_point can have at most one module_type")
if module is None:
if module_type == "adapter":
module = self.new_adapter_like(ref)
if module_type == "lora":
module = self.new_lora_like(ref)
if module_type.startswith("batch_"):
self.batch_layer[module_name].append(splitlayer)
delta_type = module_type[6:]
else:
delta_type = module_type
if not splitlayer.attach(module_name, module):
if delta_type == "adapter":
module = self.new_adapter_like(ref, **kwargs)
elif delta_type == "lora":
module = self.new_lora_like(ref, **kwargs)
if not splitlayer.split_attach(module_name, module):
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
def update(self):
self._pseudo_data_to_instantiate(self.backbone_model)
self.mark_as_delta()
def detach(self, modified_point: str, module_name: str):
if isinstance(modified_point, str):
modified_point = [modified_point]
def split_detach(self, module_name: str):
if module_name not in self.name2point:
raise ValueError(f"{module_name} not in delta model")
module_type, modified_point = self.name2point.pop(module_name)
for key in self.backbone_key_list:
if self.find_key(key, modified_point): # TODO: may have bugs when the common structure has a virtual node and it is referred to
@ -241,23 +375,105 @@ class SplitModel(DeltaBase):
splitlayer = self.modified_points[key]
if not splitlayer.detach(module_name):
module = splitlayer.split_detach(module_name)
if module is None:
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
def new_adapter_like(self, module):
adapterlayer = AdapterLayer()
if module_type.startswith("batch_"):
self.batch_layer.pop(module_name)
def save_split(self, module_name: str, save_name: str):
if module_name not in self.name2point:
raise ValueError(f"{module_name} not in delta model")
module_type, modified_point = self.name2point[module_name]
print("Save", module_name, modified_point)
module_dict = nn.ModuleDict()
for key in self.backbone_key_list:
print("find", key, modified_point, self.find_key(key, modified_point))
if self.find_key(key, modified_point): # TODO: may have bugs when the common structure has a virtual node and it is referred to
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
if key not in self.modified_points:
raise ValueError("no module has been added to {}".format(key))
splitlayer = self.modified_points[key]
module = splitlayer.split_get(module_name)
if module is None:
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
module_dict[f"{module_type}:{key.replace('.', ':')}"] = module
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
torch.save(module_dict, save_name)
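The checkpoint written here is an nn.ModuleDict whose keys encode the module type and the modified point with dots replaced by colons; a standalone round trip of that encoding (mirroring save_split and load_split) looks like:

module_type = "batch_lora"                                  # assumed type
key = "encoder.layer.0.attention.self.query"                # assumed modified point
saved_key = f"{module_type}:{key.replace('.', ':')}"
loaded_type, rest = saved_key.split(':', maxsplit=1)
loaded_key = rest.replace(':', '.')
assert (loaded_type, loaded_key) == (module_type, key)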
def load_split(self, module_name: str, load_name: str):
if module_name in self.modified_points:
raise ValueError(f"{module_name} already in delta model")
module_dict = torch.load(load_name)
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
keys = [key.split(':',maxsplit=1)[1].replace(':', '.') for key in module_dict.keys()]
module_types = [key.split(':',maxsplit=1)[0] for key in module_dict.keys()]
module_type = module_types[0]
print(keys)
print(module_type)
self.name2point[module_name] = (module_type, keys)
if module_types[0].startswith("batch_"):
self.batch_layer[module_name] = []
for key in self.backbone_key_list:
if key in keys:
logger.debug("find key: {}".format(key))
_, _, ref = self.find_module(self.backbone_model, key)
module = module_dict[list(module_dict.keys())[keys.index(key)]]
if key not in self.modified_points:
if module_type == "adapter":
splitlayer = SplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
elif module_type == "lora":
splitlayer = SplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
elif module_type == "batch_adapter":
splitlayer = BatchSplitSequentialLayer()
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
elif module_type == "batch_lora":
splitlayer = BatchSplitParallelLayer()
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
self.modified_points[key] = splitlayer
splitlayer = self.modified_points[key]
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
raise ValueError("one modified_point can have at most one module_type")
if module_type.startswith("batch_"):
self.batch_layer[module_name].append(splitlayer)
if not splitlayer.split_attach(module_name, module):
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
def new_adapter_like(self, module, **kwargs):
adapterlayer = AdapterLayer(**kwargs)
self.delta_modules.append(adapterlayer)
return adapterlayer
def new_lora_like(self, child_module):
def new_lora_like(self, child_module, **kwargs):
if isinstance(child_module, nn.Linear):
in_features, out_features = child_module.in_features, child_module.out_features
new_module = LowRankLinear(in_features = in_features,
out_features = out_features,
weight = child_module.weight,
r=self.lora_r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
**kwargs,)
self.delta_modules.append(new_module)
else:
raise NotImplementedError

test.py (26 deletions)
View File

@ -1,26 +0,0 @@
from transformers import BertModel
model = BertModel.from_pretrained("bert-base-cased")
from opendelta import Visualization
Visualization(model).structure_graph()
from opendelta import SplitModel
delta_model = SplitModel(model)
print("attach here")
delta_model.attach("output.dense", "adapter_A", "adapter")
delta_model.attach("output.dense", "adapter_B", "adapter")
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
delta_model.update()
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
delta_model.log() # This will visualize the backbone after modification and other information.
import torch
x = torch.randint(0, 10, (16, 128)).cuda()
import time
model = model.cuda()
st_time = time.time()
print("run here")
y = model(x)
print("detach here")
delta_model.detach("3.output.dense", "adapter_A")
delta_model.log()
print("run here")
y = model(x)
print(time.time() - st_time)