fix split and add batchsplit
This commit is contained in:
parent
f8db5be89b
commit
5e2c8e4c83
|
@ -0,0 +1,43 @@
|
||||||
|
from transformers import BertModel
|
||||||
|
model = BertModel.from_pretrained("bert-base-cased")
|
||||||
|
from opendelta import Visualization
|
||||||
|
Visualization(model).structure_graph()
|
||||||
|
from opendelta import SplitModel
|
||||||
|
delta_model = SplitModel(model)
|
||||||
|
|
||||||
|
print("split_attach here")
|
||||||
|
delta_model.split_attach("output.dense", "adapter_A", "adapter", bottleneck_dim=12)
|
||||||
|
delta_model.split_attach("output.dense", "adapter_B", "adapter", bottleneck_dim=16, non_linearity="relu")
|
||||||
|
# delta_model.split_attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
|
||||||
|
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "lora", r=8)
|
||||||
|
delta_model.update()
|
||||||
|
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||||
|
|
||||||
|
print("batchsplit_attach here")
|
||||||
|
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "batch_lora", r=4)
|
||||||
|
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_B", "batch_lora", r=8)
|
||||||
|
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_E", "batch_adapter")
|
||||||
|
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_F", "batch_adapter")
|
||||||
|
delta_model.update()
|
||||||
|
delta_model.log()
|
||||||
|
|
||||||
|
print("split_detach and save here")
|
||||||
|
delta_model.save_split("adapter_A", "adapter_A_split.pt")
|
||||||
|
delta_model.split_detach("adapter_A")
|
||||||
|
delta_model.save_split("lora_A", "lora_A_split.pt")
|
||||||
|
delta_model.split_detach("lora_A")
|
||||||
|
delta_model.update()
|
||||||
|
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||||
|
|
||||||
|
print("load back here")
|
||||||
|
delta_model.load_split("adapter_A", "adapter_A_split.pt")
|
||||||
|
delta_model.load_split("lora_A", "lora_A_split.pt")
|
||||||
|
delta_model.update()
|
||||||
|
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||||
|
|
||||||
|
print("run here")
|
||||||
|
import torch
|
||||||
|
x = torch.randint(0, 10, (16, 128)).cuda()
|
||||||
|
delta_model.set_batchsplit_pattern(['lora_A']*4 + ['lora_B']*12)
|
||||||
|
model = model.cuda()
|
||||||
|
y = model(x)
|
|
@ -5,6 +5,16 @@ import torch.nn as nn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.activations import get_activation
|
from transformers.activations import get_activation
|
||||||
|
|
||||||
|
def swish(x):
|
||||||
|
return x * torch.sigmoid(x)
|
||||||
|
|
||||||
|
def gelu_new(x):
|
||||||
|
"""
|
||||||
|
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||||
|
Also see https://arxiv.org/abs/1606.08415
|
||||||
|
"""
|
||||||
|
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||||
|
|
||||||
class Activations(nn.Module):
|
class Activations(nn.Module):
|
||||||
"""
|
"""
|
||||||
Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink
|
Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink
|
||||||
|
@ -17,20 +27,8 @@ class Activations(nn.Module):
|
||||||
elif activation_type.lower() == "tanh":
|
elif activation_type.lower() == "tanh":
|
||||||
self.f = torch.tanh
|
self.f = torch.tanh
|
||||||
elif activation_type.lower() == "swish":
|
elif activation_type.lower() == "swish":
|
||||||
|
|
||||||
def swish(x):
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
self.f = swish
|
self.f = swish
|
||||||
elif activation_type.lower() == "gelu_new":
|
elif activation_type.lower() == "gelu_new":
|
||||||
|
|
||||||
def gelu_new(x):
|
|
||||||
"""
|
|
||||||
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
|
||||||
Also see https://arxiv.org/abs/1606.08415
|
|
||||||
"""
|
|
||||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
|
||||||
|
|
||||||
self.f = gelu_new
|
self.f = gelu_new
|
||||||
elif activation_type.lower() == "gelu_orig":
|
elif activation_type.lower() == "gelu_orig":
|
||||||
self.f = nn.functional.gelu
|
self.f = nn.functional.gelu
|
||||||
|
|
|
@ -9,6 +9,10 @@ import torch.nn as nn
|
||||||
from opendelta import BaseDeltaConfig
|
from opendelta import BaseDeltaConfig
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
class Identical(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
class LowRankLinear(nn.Module):
|
class LowRankLinear(nn.Module):
|
||||||
# ------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------
|
||||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
@ -30,7 +34,7 @@ class LowRankLinear(nn.Module):
|
||||||
if lora_dropout > 0.:
|
if lora_dropout > 0.:
|
||||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||||
else:
|
else:
|
||||||
self.lora_dropout = lambda x: x
|
self.lora_dropout = Identical()
|
||||||
if r > 0:
|
if r > 0:
|
||||||
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
|
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
|
||||||
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
|
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
|
from html.entities import name2codepoint
|
||||||
from random import random
|
from random import random
|
||||||
from sqlite3 import adapters
|
from sqlite3 import adapters
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from cv2 import accumulate
|
from opendelta.utils.signature import get_arg_names_inside_func, signature
|
||||||
from opendelta.utils.signature import get_arg_names_inside_func
|
|
||||||
from opendelta.utils.name_based_addressing import *
|
from opendelta.utils.name_based_addressing import *
|
||||||
from opendelta.utils.cuda import get_device
|
from opendelta.utils.cuda import get_device
|
||||||
from opendelta.basemodel import DeltaBase
|
from opendelta.basemodel import DeltaBase
|
||||||
import loralib as lora
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
|
@ -20,7 +19,6 @@ import opendelta.utils.logging as logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from opendelta import global_setting
|
from opendelta import global_setting
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
from itertools import accumulate
|
|
||||||
from opendelta.delta_models.adapter import AdapterLayer
|
from opendelta.delta_models.adapter import AdapterLayer
|
||||||
from opendelta.delta_models.lora import LowRankLinear
|
from opendelta.delta_models.lora import LowRankLinear
|
||||||
|
|
||||||
|
@ -31,17 +29,44 @@ class _SplitLayer(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module_dict = nn.ModuleDict()
|
self.module_dict = nn.ModuleDict()
|
||||||
|
|
||||||
def attach(self, module_name: str, module: nn.Module):
|
def split_attach(self, module_name: str, module: nn.Module):
|
||||||
if module_name in self.module_dict:
|
if module_name in self.module_dict:
|
||||||
return False
|
return False
|
||||||
self.module_dict[module_name] = module
|
self.module_dict[module_name] = module
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def detach(self, module_name: str):
|
def split_detach(self, module_name: str):
|
||||||
if module_name not in self.module_dict:
|
if module_name not in self.module_dict:
|
||||||
return False
|
return None
|
||||||
self.module_dict.pop(module_name)
|
return self.module_dict.pop(module_name)
|
||||||
return True
|
|
||||||
|
def split_get(self, module_name: str):
|
||||||
|
if module_name not in self.module_dict:
|
||||||
|
return None
|
||||||
|
return self.module_dict[module_name]
|
||||||
|
|
||||||
|
class _BatchSplitLayer(_SplitLayer):
|
||||||
|
r"""A layer of batch splitting module.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.split_pattern = {}
|
||||||
|
|
||||||
|
def set_batchsplit_pattern(self, list_pattern):
|
||||||
|
self.split_pattern = {}
|
||||||
|
self.list_pattern = list_pattern
|
||||||
|
self.count = []
|
||||||
|
for i, name in enumerate(list_pattern):
|
||||||
|
if name not in self.split_pattern:
|
||||||
|
self.split_pattern[name] = []
|
||||||
|
self.count.append(len(self.split_pattern[name]))
|
||||||
|
self.split_pattern[name].append(i)
|
||||||
|
|
||||||
|
def get_batchsplit_pattern(self, name):
|
||||||
|
return self.split_pattern.get(name, None)
|
||||||
|
|
||||||
|
def merge_by_pattern(self, output_dict):
|
||||||
|
return torch.stack([output_dict[name][self.count[i]] for i, name in enumerate(self.list_pattern)], dim=0)
|
||||||
|
|
||||||
class SplitSequentialLayer(_SplitLayer):
|
class SplitSequentialLayer(_SplitLayer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -60,7 +85,7 @@ class SplitSequentialLayer(_SplitLayer):
|
||||||
split_outputs.append( module.post_forward(
|
split_outputs.append( module.post_forward(
|
||||||
hiddens
|
hiddens
|
||||||
) )
|
) )
|
||||||
print(len(split_outputs))
|
print("sequential", len(split_outputs))
|
||||||
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
||||||
|
|
||||||
if isinstance(output, tuple):
|
if isinstance(output, tuple):
|
||||||
|
@ -69,11 +94,11 @@ class SplitSequentialLayer(_SplitLayer):
|
||||||
output = merge_output
|
output = merge_output
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
print(hiddens.shape)
|
||||||
|
print(merge_output.shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
class SplitParallelLayer(_SplitLayer):
|
class SplitParallelLayer(_SplitLayer):
|
||||||
r"""A layer of splitting module.
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -83,10 +108,58 @@ class SplitParallelLayer(_SplitLayer):
|
||||||
split_outputs.append( module(
|
split_outputs.append( module(
|
||||||
hiddens
|
hiddens
|
||||||
) )
|
) )
|
||||||
print(len(split_outputs))
|
print("paralell", len(split_outputs))
|
||||||
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
||||||
return merge_output
|
return merge_output
|
||||||
|
|
||||||
|
class BatchSplitSequentialLayer(_BatchSplitLayer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def post_forward(self, output):
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
hiddens = output[0]
|
||||||
|
elif isinstance(output, torch.Tensor):
|
||||||
|
hiddens = output
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
split_outputs = {}
|
||||||
|
for module_name, module in self.module_dict.items():
|
||||||
|
pattern = self.get_batchsplit_pattern(module_name)
|
||||||
|
if pattern is not None:
|
||||||
|
split_outputs[module_name] = module.post_forward(
|
||||||
|
hiddens[pattern],
|
||||||
|
)
|
||||||
|
merge_output = self.merge_by_pattern(split_outputs)
|
||||||
|
print(hiddens.shape)
|
||||||
|
print(merge_output.shape)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (merge_output,) + output[1:]
|
||||||
|
elif isinstance(output, torch.Tensor):
|
||||||
|
output = merge_output
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
return output
|
||||||
|
|
||||||
|
class BatchSplitParallelLayer(_BatchSplitLayer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, hiddens):
|
||||||
|
split_outputs = {}
|
||||||
|
for module_name, module in self.module_dict.items():
|
||||||
|
pattern = self.get_batchsplit_pattern(module_name)
|
||||||
|
if pattern != None:
|
||||||
|
split_outputs[module_name] = module(
|
||||||
|
hiddens[pattern],
|
||||||
|
)
|
||||||
|
merge_output = self.merge_by_pattern(split_outputs)
|
||||||
|
print(hiddens.shape)
|
||||||
|
print(merge_output.shape)
|
||||||
|
return merge_output
|
||||||
|
|
||||||
class SplitConfig(BaseDeltaConfig):
|
class SplitConfig(BaseDeltaConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a :py:class:`~SplitModel`
|
This is the configuration class to store the configuration of a :py:class:`~SplitModel`
|
||||||
|
@ -134,7 +207,7 @@ class SplitModel(DeltaBase):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_class = SplitConfig
|
config_class = SplitConfig
|
||||||
delta_type = "batch_split"
|
delta_type = "split"
|
||||||
default_modified_modules = ["attn", "ff"]
|
default_modified_modules = ["attn", "ff"]
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
backbone_model: nn.Module,
|
backbone_model: nn.Module,
|
||||||
|
@ -163,6 +236,8 @@ class SplitModel(DeltaBase):
|
||||||
self.modified_modules,
|
self.modified_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.name2point = {}
|
||||||
|
self.batch_layer = {}
|
||||||
self.modified_points = {}
|
self.modified_points = {}
|
||||||
|
|
||||||
def add_all_delta_to_backbone(self,
|
def add_all_delta_to_backbone(self,
|
||||||
|
@ -188,12 +263,61 @@ class SplitModel(DeltaBase):
|
||||||
self.backbone_key_list = [key for key, _ in backbone.named_modules()]
|
self.backbone_key_list = [key for key, _ in backbone.named_modules()]
|
||||||
return backbone
|
return backbone
|
||||||
|
|
||||||
def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None):
|
def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
|
||||||
|
r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph.
|
||||||
|
First try to use the dummy_inputs of the pretrained model. If the model has no dummy_inputs, will try to create
|
||||||
|
integer tensor as the pseudo_input, if ``decoder_input_ids`` is in the model's forward function, additional create it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
device = get_device(module)
|
||||||
|
logger.warning("No dummy_inputs attributes, create a common input_ids for input.")
|
||||||
|
if len(self.batch_layer) > 0:
|
||||||
|
pseudo_input = torch.tensor([[0,0,0,0]]*len(self.batch_layer)).to(device)
|
||||||
|
self.set_batchsplit_pattern(list(self.batch_layer.keys()))
|
||||||
|
else:
|
||||||
|
pseudo_input = torch.tensor([[0,0,0,0]]).to(device)
|
||||||
|
print(pseudo_input)
|
||||||
|
if "decoder_input_ids" in signature(module.forward).args:
|
||||||
|
module(pseudo_input, decoder_input_ids = pseudo_input)
|
||||||
|
else:
|
||||||
|
module(pseudo_input)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
self._pseudo_data_to_instantiate(self.backbone_model)
|
||||||
|
self.mark_as_delta()
|
||||||
|
|
||||||
|
def set_batchsplit_pattern(self,
|
||||||
|
pattern: List,
|
||||||
|
):
|
||||||
|
r"""Set the batch split pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern (:obj:`List`): The batch split pattern.
|
||||||
|
|
||||||
|
"""
|
||||||
|
for module_name, layer_list in self.batch_layer.items():
|
||||||
|
for batch_layer in layer_list:
|
||||||
|
batch_layer.set_batchsplit_pattern(pattern)
|
||||||
|
|
||||||
|
def split_attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, **kwargs):
|
||||||
|
if module_name in self.modified_points:
|
||||||
|
raise ValueError(f"{module_name} already in delta model")
|
||||||
|
|
||||||
if isinstance(modified_point, str):
|
if isinstance(modified_point, str):
|
||||||
modified_point = [modified_point]
|
modified_point = [modified_point]
|
||||||
|
|
||||||
if module_type not in ["adapter", "lora"]:
|
self.name2point[module_name] = (module_type, modified_point)
|
||||||
raise ValueError("module_type must be either adapter or lora")
|
|
||||||
|
advailable = ["adapter", "lora"]
|
||||||
|
advailable += [f"batch_{a}" for a in advailable]
|
||||||
|
if module_type not in advailable:
|
||||||
|
raise ValueError(f"module_type must be in {' '.join(advailable)}.")
|
||||||
|
|
||||||
|
if module_type.startswith("batch_"):
|
||||||
|
self.batch_layer[module_name] = []
|
||||||
|
|
||||||
for key in self.backbone_key_list:
|
for key in self.backbone_key_list:
|
||||||
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||||
|
@ -204,32 +328,42 @@ class SplitModel(DeltaBase):
|
||||||
if module_type == "adapter":
|
if module_type == "adapter":
|
||||||
splitlayer = SplitSequentialLayer()
|
splitlayer = SplitSequentialLayer()
|
||||||
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
|
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
|
||||||
if module_type == "lora":
|
elif module_type == "lora":
|
||||||
splitlayer = SplitParallelLayer()
|
splitlayer = SplitParallelLayer()
|
||||||
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
|
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
|
||||||
|
elif module_type == "batch_adapter":
|
||||||
|
splitlayer = BatchSplitSequentialLayer()
|
||||||
|
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
|
||||||
|
elif module_type == "batch_lora":
|
||||||
|
splitlayer = BatchSplitParallelLayer()
|
||||||
|
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
|
||||||
self.modified_points[key] = splitlayer
|
self.modified_points[key] = splitlayer
|
||||||
|
|
||||||
splitlayer = self.modified_points[key]
|
splitlayer = self.modified_points[key]
|
||||||
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
|
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
|
||||||
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)):
|
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
|
||||||
|
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
|
||||||
|
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
|
||||||
raise ValueError("one modified_point can have at most one module_type")
|
raise ValueError("one modified_point can have at most one module_type")
|
||||||
|
|
||||||
if module is None:
|
if module_type.startswith("batch_"):
|
||||||
if module_type == "adapter":
|
self.batch_layer[module_name].append(splitlayer)
|
||||||
module = self.new_adapter_like(ref)
|
delta_type = module_type[6:]
|
||||||
if module_type == "lora":
|
else:
|
||||||
module = self.new_lora_like(ref)
|
delta_type = module_type
|
||||||
|
|
||||||
if not splitlayer.attach(module_name, module):
|
if delta_type == "adapter":
|
||||||
|
module = self.new_adapter_like(ref, **kwargs)
|
||||||
|
elif delta_type == "lora":
|
||||||
|
module = self.new_lora_like(ref, **kwargs)
|
||||||
|
|
||||||
|
if not splitlayer.split_attach(module_name, module):
|
||||||
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
|
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
|
||||||
|
|
||||||
def update(self):
|
def split_detach(self, module_name: str):
|
||||||
self._pseudo_data_to_instantiate(self.backbone_model)
|
if module_name not in self.name2point:
|
||||||
self.mark_as_delta()
|
raise ValueError(f"{module_name} not in delta model")
|
||||||
|
module_type, modified_point = self.name2point.pop(module_name)
|
||||||
def detach(self, modified_point: str, module_name: str):
|
|
||||||
if isinstance(modified_point, str):
|
|
||||||
modified_point = [modified_point]
|
|
||||||
|
|
||||||
for key in self.backbone_key_list:
|
for key in self.backbone_key_list:
|
||||||
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||||
|
@ -241,23 +375,105 @@ class SplitModel(DeltaBase):
|
||||||
|
|
||||||
splitlayer = self.modified_points[key]
|
splitlayer = self.modified_points[key]
|
||||||
|
|
||||||
if not splitlayer.detach(module_name):
|
module = splitlayer.split_detach(module_name)
|
||||||
|
if module is None:
|
||||||
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
|
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
|
||||||
|
|
||||||
def new_adapter_like(self, module):
|
if module_type.startswith("batch_"):
|
||||||
adapterlayer = AdapterLayer()
|
self.batch_layer.pop(module_name)
|
||||||
|
|
||||||
|
def save_split(self, module_name: str, save_name: str):
|
||||||
|
if module_name not in self.name2point:
|
||||||
|
raise ValueError(f"{module_name} not in delta model")
|
||||||
|
module_type, modified_point = self.name2point[module_name]
|
||||||
|
print("Save", module_name, modified_point)
|
||||||
|
|
||||||
|
module_dict = nn.ModuleDict()
|
||||||
|
|
||||||
|
for key in self.backbone_key_list:
|
||||||
|
print("find", key, modified_point, self.find_key(key, modified_point))
|
||||||
|
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||||
|
logger.debug("find key: {}".format(key))
|
||||||
|
_, _, ref = self.find_module(self.backbone_model, key)
|
||||||
|
|
||||||
|
if key not in self.modified_points:
|
||||||
|
raise ValueError("no module has been added to {}".format(key))
|
||||||
|
|
||||||
|
splitlayer = self.modified_points[key]
|
||||||
|
|
||||||
|
module = splitlayer.split_get(module_name)
|
||||||
|
if module is None:
|
||||||
|
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
|
||||||
|
|
||||||
|
module_dict[f"{module_type}:{key.replace('.', ':')}"] = module
|
||||||
|
|
||||||
|
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
|
||||||
|
torch.save(module_dict, save_name)
|
||||||
|
|
||||||
|
def load_split(self, module_name: str, load_name: str):
|
||||||
|
if module_name in self.modified_points:
|
||||||
|
raise ValueError(f"{module_name} already in delta model")
|
||||||
|
|
||||||
|
module_dict = torch.load(load_name)
|
||||||
|
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
|
||||||
|
|
||||||
|
keys = [key.split(':',maxsplit=1)[1].replace(':', '.') for key in module_dict.keys()]
|
||||||
|
module_types = [key.split(':',maxsplit=1)[0] for key in module_dict.keys()]
|
||||||
|
module_type = module_types[0]
|
||||||
|
print(keys)
|
||||||
|
print(module_type)
|
||||||
|
|
||||||
|
self.name2point[module_name] = (module_type, keys)
|
||||||
|
|
||||||
|
if module_types[0].startswith("batch_"):
|
||||||
|
self.batch_layer[module_name] = []
|
||||||
|
|
||||||
|
for key in self.backbone_key_list:
|
||||||
|
if key in keys:
|
||||||
|
logger.debug("find key: {}".format(key))
|
||||||
|
_, _, ref = self.find_module(self.backbone_model, key)
|
||||||
|
module = module_dict[list(module_dict.keys())[keys.index(key)]]
|
||||||
|
|
||||||
|
if key not in self.modified_points:
|
||||||
|
if module_type == "adapter":
|
||||||
|
splitlayer = SplitSequentialLayer()
|
||||||
|
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
|
||||||
|
elif module_type == "lora":
|
||||||
|
splitlayer = SplitParallelLayer()
|
||||||
|
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
|
||||||
|
elif module_type == "batch_adapter":
|
||||||
|
splitlayer = BatchSplitSequentialLayer()
|
||||||
|
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
|
||||||
|
elif module_type == "batch_lora":
|
||||||
|
splitlayer = BatchSplitParallelLayer()
|
||||||
|
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
|
||||||
|
self.modified_points[key] = splitlayer
|
||||||
|
|
||||||
|
splitlayer = self.modified_points[key]
|
||||||
|
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
|
||||||
|
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
|
||||||
|
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
|
||||||
|
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
|
||||||
|
raise ValueError("one modified_point can have at most one module_type")
|
||||||
|
|
||||||
|
if module_type.startswith("batch_"):
|
||||||
|
self.batch_layer[module_name].append(splitlayer)
|
||||||
|
|
||||||
|
if not splitlayer.split_attach(module_name, module):
|
||||||
|
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
|
||||||
|
|
||||||
|
def new_adapter_like(self, module, **kwargs):
|
||||||
|
adapterlayer = AdapterLayer(**kwargs)
|
||||||
self.delta_modules.append(adapterlayer)
|
self.delta_modules.append(adapterlayer)
|
||||||
return adapterlayer
|
return adapterlayer
|
||||||
|
|
||||||
def new_lora_like(self, child_module):
|
def new_lora_like(self, child_module, **kwargs):
|
||||||
if isinstance(child_module, nn.Linear):
|
if isinstance(child_module, nn.Linear):
|
||||||
in_features, out_features = child_module.in_features, child_module.out_features
|
in_features, out_features = child_module.in_features, child_module.out_features
|
||||||
new_module = LowRankLinear(in_features = in_features,
|
new_module = LowRankLinear(in_features = in_features,
|
||||||
out_features = out_features,
|
out_features = out_features,
|
||||||
weight = child_module.weight,
|
weight = child_module.weight,
|
||||||
r=self.lora_r,
|
**kwargs,)
|
||||||
lora_alpha=self.lora_alpha,
|
|
||||||
lora_dropout=self.lora_dropout)
|
|
||||||
self.delta_modules.append(new_module)
|
self.delta_modules.append(new_module)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
26
test.py
26
test.py
|
@ -1,26 +0,0 @@
|
||||||
from transformers import BertModel
|
|
||||||
model = BertModel.from_pretrained("bert-base-cased")
|
|
||||||
from opendelta import Visualization
|
|
||||||
Visualization(model).structure_graph()
|
|
||||||
from opendelta import SplitModel
|
|
||||||
delta_model = SplitModel(model)
|
|
||||||
print("attach here")
|
|
||||||
delta_model.attach("output.dense", "adapter_A", "adapter")
|
|
||||||
delta_model.attach("output.dense", "adapter_B", "adapter")
|
|
||||||
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
|
|
||||||
delta_model.update()
|
|
||||||
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
|
|
||||||
delta_model.log() # This will visualize the backbone after modification and other information.
|
|
||||||
import torch
|
|
||||||
x = torch.randint(0, 10, (16, 128)).cuda()
|
|
||||||
import time
|
|
||||||
model = model.cuda()
|
|
||||||
st_time = time.time()
|
|
||||||
print("run here")
|
|
||||||
y = model(x)
|
|
||||||
print("detach here")
|
|
||||||
delta_model.detach("3.output.dense", "adapter_A")
|
|
||||||
delta_model.log()
|
|
||||||
print("run here")
|
|
||||||
y = model(x)
|
|
||||||
print(time.time() - st_time)
|
|
Loading…
Reference in New Issue