fix split and add batchsplit
This commit is contained in:
parent
f8db5be89b
commit
5e2c8e4c83
|
@ -0,0 +1,43 @@
|
|||
from transformers import BertModel
|
||||
model = BertModel.from_pretrained("bert-base-cased")
|
||||
from opendelta import Visualization
|
||||
Visualization(model).structure_graph()
|
||||
from opendelta import SplitModel
|
||||
delta_model = SplitModel(model)
|
||||
|
||||
print("split_attach here")
|
||||
delta_model.split_attach("output.dense", "adapter_A", "adapter", bottleneck_dim=12)
|
||||
delta_model.split_attach("output.dense", "adapter_B", "adapter", bottleneck_dim=16, non_linearity="relu")
|
||||
# delta_model.split_attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
|
||||
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "lora", r=8)
|
||||
delta_model.update()
|
||||
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||
|
||||
print("batchsplit_attach here")
|
||||
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_A", "batch_lora", r=4)
|
||||
delta_model.split_attach(["attention.self.query", "attention.self.key"], "lora_B", "batch_lora", r=8)
|
||||
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_E", "batch_adapter")
|
||||
# delta_model.split_attach(["attention.self.query", "attention.self.key"], "adapter_F", "batch_adapter")
|
||||
delta_model.update()
|
||||
delta_model.log()
|
||||
|
||||
print("split_detach and save here")
|
||||
delta_model.save_split("adapter_A", "adapter_A_split.pt")
|
||||
delta_model.split_detach("adapter_A")
|
||||
delta_model.save_split("lora_A", "lora_A_split.pt")
|
||||
delta_model.split_detach("lora_A")
|
||||
delta_model.update()
|
||||
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||
|
||||
print("load back here")
|
||||
delta_model.load_split("adapter_A", "adapter_A_split.pt")
|
||||
delta_model.load_split("lora_A", "lora_A_split.pt")
|
||||
delta_model.update()
|
||||
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||
|
||||
print("run here")
|
||||
import torch
|
||||
x = torch.randint(0, 10, (16, 128)).cuda()
|
||||
delta_model.set_batchsplit_pattern(['lora_A']*4 + ['lora_B']*12)
|
||||
model = model.cuda()
|
||||
y = model(x)
|
|
@ -5,6 +5,16 @@ import torch.nn as nn
|
|||
import torch.nn as nn
|
||||
from transformers.activations import get_activation
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
def gelu_new(x):
|
||||
"""
|
||||
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
class Activations(nn.Module):
|
||||
"""
|
||||
Implementation of various activation function. Copied from open-source project AdapterHub #TODO: addlink
|
||||
|
@ -17,20 +27,8 @@ class Activations(nn.Module):
|
|||
elif activation_type.lower() == "tanh":
|
||||
self.f = torch.tanh
|
||||
elif activation_type.lower() == "swish":
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
self.f = swish
|
||||
elif activation_type.lower() == "gelu_new":
|
||||
|
||||
def gelu_new(x):
|
||||
"""
|
||||
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
self.f = gelu_new
|
||||
elif activation_type.lower() == "gelu_orig":
|
||||
self.f = nn.functional.gelu
|
||||
|
|
|
@ -9,6 +9,10 @@ import torch.nn as nn
|
|||
from opendelta import BaseDeltaConfig
|
||||
import math
|
||||
|
||||
class Identical(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
class LowRankLinear(nn.Module):
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
|
@ -30,7 +34,7 @@ class LowRankLinear(nn.Module):
|
|||
if lora_dropout > 0.:
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
self.lora_dropout = lambda x: x
|
||||
self.lora_dropout = Identical()
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from functools import partial
|
||||
from hashlib import sha1
|
||||
from html.entities import name2codepoint
|
||||
from random import random
|
||||
from sqlite3 import adapters
|
||||
from typing import Optional, Union
|
||||
|
||||
from cv2 import accumulate
|
||||
from opendelta.utils.signature import get_arg_names_inside_func
|
||||
from opendelta.utils.signature import get_arg_names_inside_func, signature
|
||||
from opendelta.utils.name_based_addressing import *
|
||||
from opendelta.utils.cuda import get_device
|
||||
from opendelta.basemodel import DeltaBase
|
||||
import loralib as lora
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import math
|
||||
|
@ -20,7 +19,6 @@ import opendelta.utils.logging as logging
|
|||
import numpy as np
|
||||
from opendelta import global_setting
|
||||
logger = logging.get_logger(__name__)
|
||||
from itertools import accumulate
|
||||
from opendelta.delta_models.adapter import AdapterLayer
|
||||
from opendelta.delta_models.lora import LowRankLinear
|
||||
|
||||
|
@ -31,17 +29,44 @@ class _SplitLayer(nn.Module):
|
|||
super().__init__()
|
||||
self.module_dict = nn.ModuleDict()
|
||||
|
||||
def attach(self, module_name: str, module: nn.Module):
|
||||
def split_attach(self, module_name: str, module: nn.Module):
|
||||
if module_name in self.module_dict:
|
||||
return False
|
||||
self.module_dict[module_name] = module
|
||||
return True
|
||||
|
||||
def detach(self, module_name: str):
|
||||
def split_detach(self, module_name: str):
|
||||
if module_name not in self.module_dict:
|
||||
return False
|
||||
self.module_dict.pop(module_name)
|
||||
return True
|
||||
return None
|
||||
return self.module_dict.pop(module_name)
|
||||
|
||||
def split_get(self, module_name: str):
|
||||
if module_name not in self.module_dict:
|
||||
return None
|
||||
return self.module_dict[module_name]
|
||||
|
||||
class _BatchSplitLayer(_SplitLayer):
|
||||
r"""A layer of batch splitting module.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.split_pattern = {}
|
||||
|
||||
def set_batchsplit_pattern(self, list_pattern):
|
||||
self.split_pattern = {}
|
||||
self.list_pattern = list_pattern
|
||||
self.count = []
|
||||
for i, name in enumerate(list_pattern):
|
||||
if name not in self.split_pattern:
|
||||
self.split_pattern[name] = []
|
||||
self.count.append(len(self.split_pattern[name]))
|
||||
self.split_pattern[name].append(i)
|
||||
|
||||
def get_batchsplit_pattern(self, name):
|
||||
return self.split_pattern.get(name, None)
|
||||
|
||||
def merge_by_pattern(self, output_dict):
|
||||
return torch.stack([output_dict[name][self.count[i]] for i, name in enumerate(self.list_pattern)], dim=0)
|
||||
|
||||
class SplitSequentialLayer(_SplitLayer):
|
||||
def __init__(self):
|
||||
|
@ -60,7 +85,7 @@ class SplitSequentialLayer(_SplitLayer):
|
|||
split_outputs.append( module.post_forward(
|
||||
hiddens
|
||||
) )
|
||||
print(len(split_outputs))
|
||||
print("sequential", len(split_outputs))
|
||||
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
|
@ -69,11 +94,11 @@ class SplitSequentialLayer(_SplitLayer):
|
|||
output = merge_output
|
||||
else:
|
||||
raise TypeError
|
||||
print(hiddens.shape)
|
||||
print(merge_output.shape)
|
||||
return output
|
||||
|
||||
class SplitParallelLayer(_SplitLayer):
|
||||
r"""A layer of splitting module.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -83,10 +108,58 @@ class SplitParallelLayer(_SplitLayer):
|
|||
split_outputs.append( module(
|
||||
hiddens
|
||||
) )
|
||||
print(len(split_outputs))
|
||||
print("paralell", len(split_outputs))
|
||||
merge_output = torch.sum(torch.stack(split_outputs, dim=0), dim=0)
|
||||
return merge_output
|
||||
|
||||
class BatchSplitSequentialLayer(_BatchSplitLayer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def post_forward(self, output):
|
||||
if isinstance(output, tuple):
|
||||
hiddens = output[0]
|
||||
elif isinstance(output, torch.Tensor):
|
||||
hiddens = output
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
split_outputs = {}
|
||||
for module_name, module in self.module_dict.items():
|
||||
pattern = self.get_batchsplit_pattern(module_name)
|
||||
if pattern is not None:
|
||||
split_outputs[module_name] = module.post_forward(
|
||||
hiddens[pattern],
|
||||
)
|
||||
merge_output = self.merge_by_pattern(split_outputs)
|
||||
print(hiddens.shape)
|
||||
print(merge_output.shape)
|
||||
|
||||
if isinstance(output, tuple):
|
||||
output = (merge_output,) + output[1:]
|
||||
elif isinstance(output, torch.Tensor):
|
||||
output = merge_output
|
||||
else:
|
||||
raise TypeError
|
||||
return output
|
||||
|
||||
class BatchSplitParallelLayer(_BatchSplitLayer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hiddens):
|
||||
split_outputs = {}
|
||||
for module_name, module in self.module_dict.items():
|
||||
pattern = self.get_batchsplit_pattern(module_name)
|
||||
if pattern != None:
|
||||
split_outputs[module_name] = module(
|
||||
hiddens[pattern],
|
||||
)
|
||||
merge_output = self.merge_by_pattern(split_outputs)
|
||||
print(hiddens.shape)
|
||||
print(merge_output.shape)
|
||||
return merge_output
|
||||
|
||||
class SplitConfig(BaseDeltaConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a :py:class:`~SplitModel`
|
||||
|
@ -134,7 +207,7 @@ class SplitModel(DeltaBase):
|
|||
|
||||
"""
|
||||
config_class = SplitConfig
|
||||
delta_type = "batch_split"
|
||||
delta_type = "split"
|
||||
default_modified_modules = ["attn", "ff"]
|
||||
def __init__(self,
|
||||
backbone_model: nn.Module,
|
||||
|
@ -163,6 +236,8 @@ class SplitModel(DeltaBase):
|
|||
self.modified_modules,
|
||||
)
|
||||
|
||||
self.name2point = {}
|
||||
self.batch_layer = {}
|
||||
self.modified_points = {}
|
||||
|
||||
def add_all_delta_to_backbone(self,
|
||||
|
@ -188,12 +263,61 @@ class SplitModel(DeltaBase):
|
|||
self.backbone_key_list = [key for key, _ in backbone.named_modules()]
|
||||
return backbone
|
||||
|
||||
def attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, module: Optional[nn.Module] = None):
|
||||
def _pseudo_data_to_instantiate(self, module: Optional[nn.Module]=None):
|
||||
r"""Create a pseudo_data into the module to know the dimemsion of each tensor in the computation graph.
|
||||
First try to use the dummy_inputs of the pretrained model. If the model has no dummy_inputs, will try to create
|
||||
integer tensor as the pseudo_input, if ``decoder_input_ids`` is in the model's forward function, additional create it.
|
||||
|
||||
Args:
|
||||
module (:obj:`nn.Module`, *optional*, default to :obj:`None`): The backbone model.
|
||||
|
||||
"""
|
||||
device = get_device(module)
|
||||
logger.warning("No dummy_inputs attributes, create a common input_ids for input.")
|
||||
if len(self.batch_layer) > 0:
|
||||
pseudo_input = torch.tensor([[0,0,0,0]]*len(self.batch_layer)).to(device)
|
||||
self.set_batchsplit_pattern(list(self.batch_layer.keys()))
|
||||
else:
|
||||
pseudo_input = torch.tensor([[0,0,0,0]]).to(device)
|
||||
print(pseudo_input)
|
||||
if "decoder_input_ids" in signature(module.forward).args:
|
||||
module(pseudo_input, decoder_input_ids = pseudo_input)
|
||||
else:
|
||||
module(pseudo_input)
|
||||
|
||||
def update(self):
|
||||
self._pseudo_data_to_instantiate(self.backbone_model)
|
||||
self.mark_as_delta()
|
||||
|
||||
def set_batchsplit_pattern(self,
|
||||
pattern: List,
|
||||
):
|
||||
r"""Set the batch split pattern.
|
||||
|
||||
Args:
|
||||
pattern (:obj:`List`): The batch split pattern.
|
||||
|
||||
"""
|
||||
for module_name, layer_list in self.batch_layer.items():
|
||||
for batch_layer in layer_list:
|
||||
batch_layer.set_batchsplit_pattern(pattern)
|
||||
|
||||
def split_attach(self, modified_point: Union[str, List[str]], module_name: str, module_type: str, **kwargs):
|
||||
if module_name in self.modified_points:
|
||||
raise ValueError(f"{module_name} already in delta model")
|
||||
|
||||
if isinstance(modified_point, str):
|
||||
modified_point = [modified_point]
|
||||
|
||||
if module_type not in ["adapter", "lora"]:
|
||||
raise ValueError("module_type must be either adapter or lora")
|
||||
self.name2point[module_name] = (module_type, modified_point)
|
||||
|
||||
advailable = ["adapter", "lora"]
|
||||
advailable += [f"batch_{a}" for a in advailable]
|
||||
if module_type not in advailable:
|
||||
raise ValueError(f"module_type must be in {' '.join(advailable)}.")
|
||||
|
||||
if module_type.startswith("batch_"):
|
||||
self.batch_layer[module_name] = []
|
||||
|
||||
for key in self.backbone_key_list:
|
||||
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||
|
@ -204,32 +328,42 @@ class SplitModel(DeltaBase):
|
|||
if module_type == "adapter":
|
||||
splitlayer = SplitSequentialLayer()
|
||||
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
|
||||
if module_type == "lora":
|
||||
elif module_type == "lora":
|
||||
splitlayer = SplitParallelLayer()
|
||||
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
|
||||
elif module_type == "batch_adapter":
|
||||
splitlayer = BatchSplitSequentialLayer()
|
||||
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
|
||||
elif module_type == "batch_lora":
|
||||
splitlayer = BatchSplitParallelLayer()
|
||||
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
|
||||
self.modified_points[key] = splitlayer
|
||||
|
||||
splitlayer = self.modified_points[key]
|
||||
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
|
||||
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)):
|
||||
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
|
||||
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
|
||||
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
|
||||
raise ValueError("one modified_point can have at most one module_type")
|
||||
|
||||
if module is None:
|
||||
if module_type == "adapter":
|
||||
module = self.new_adapter_like(ref)
|
||||
if module_type == "lora":
|
||||
module = self.new_lora_like(ref)
|
||||
if module_type.startswith("batch_"):
|
||||
self.batch_layer[module_name].append(splitlayer)
|
||||
delta_type = module_type[6:]
|
||||
else:
|
||||
delta_type = module_type
|
||||
|
||||
if not splitlayer.attach(module_name, module):
|
||||
if delta_type == "adapter":
|
||||
module = self.new_adapter_like(ref, **kwargs)
|
||||
elif delta_type == "lora":
|
||||
module = self.new_lora_like(ref, **kwargs)
|
||||
|
||||
if not splitlayer.split_attach(module_name, module):
|
||||
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
|
||||
|
||||
def update(self):
|
||||
self._pseudo_data_to_instantiate(self.backbone_model)
|
||||
self.mark_as_delta()
|
||||
|
||||
def detach(self, modified_point: str, module_name: str):
|
||||
if isinstance(modified_point, str):
|
||||
modified_point = [modified_point]
|
||||
def split_detach(self, module_name: str):
|
||||
if module_name not in self.name2point:
|
||||
raise ValueError(f"{module_name} not in delta model")
|
||||
module_type, modified_point = self.name2point.pop(module_name)
|
||||
|
||||
for key in self.backbone_key_list:
|
||||
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||
|
@ -241,23 +375,105 @@ class SplitModel(DeltaBase):
|
|||
|
||||
splitlayer = self.modified_points[key]
|
||||
|
||||
if not splitlayer.detach(module_name):
|
||||
module = splitlayer.split_detach(module_name)
|
||||
if module is None:
|
||||
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
|
||||
|
||||
def new_adapter_like(self, module):
|
||||
adapterlayer = AdapterLayer()
|
||||
if module_type.startswith("batch_"):
|
||||
self.batch_layer.pop(module_name)
|
||||
|
||||
def save_split(self, module_name: str, save_name: str):
|
||||
if module_name not in self.name2point:
|
||||
raise ValueError(f"{module_name} not in delta model")
|
||||
module_type, modified_point = self.name2point[module_name]
|
||||
print("Save", module_name, modified_point)
|
||||
|
||||
module_dict = nn.ModuleDict()
|
||||
|
||||
for key in self.backbone_key_list:
|
||||
print("find", key, modified_point, self.find_key(key, modified_point))
|
||||
if self.find_key(key, modified_point): # TODO may have bugs when commonstructure has a virtual node and it's refered
|
||||
logger.debug("find key: {}".format(key))
|
||||
_, _, ref = self.find_module(self.backbone_model, key)
|
||||
|
||||
if key not in self.modified_points:
|
||||
raise ValueError("no module has been added to {}".format(key))
|
||||
|
||||
splitlayer = self.modified_points[key]
|
||||
|
||||
module = splitlayer.split_get(module_name)
|
||||
if module is None:
|
||||
raise ValueError("no module with the name '{}' has been added to {}".format(module_name, key))
|
||||
|
||||
module_dict[f"{module_type}:{key.replace('.', ':')}"] = module
|
||||
|
||||
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
|
||||
torch.save(module_dict, save_name)
|
||||
|
||||
def load_split(self, module_name: str, load_name: str):
|
||||
if module_name in self.modified_points:
|
||||
raise ValueError(f"{module_name} already in delta model")
|
||||
|
||||
module_dict = torch.load(load_name)
|
||||
print(module_dict[list(module_dict.keys())[0]], module_dict[list(module_dict.keys())[-1]])
|
||||
|
||||
keys = [key.split(':',maxsplit=1)[1].replace(':', '.') for key in module_dict.keys()]
|
||||
module_types = [key.split(':',maxsplit=1)[0] for key in module_dict.keys()]
|
||||
module_type = module_types[0]
|
||||
print(keys)
|
||||
print(module_type)
|
||||
|
||||
self.name2point[module_name] = (module_type, keys)
|
||||
|
||||
if module_types[0].startswith("batch_"):
|
||||
self.batch_layer[module_name] = []
|
||||
|
||||
for key in self.backbone_key_list:
|
||||
if key in keys:
|
||||
logger.debug("find key: {}".format(key))
|
||||
_, _, ref = self.find_module(self.backbone_model, key)
|
||||
module = module_dict[list(module_dict.keys())[keys.index(key)]]
|
||||
|
||||
if key not in self.modified_points:
|
||||
if module_type == "adapter":
|
||||
splitlayer = SplitSequentialLayer()
|
||||
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="split_sequential")
|
||||
elif module_type == "lora":
|
||||
splitlayer = SplitParallelLayer()
|
||||
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="split_parallel")
|
||||
elif module_type == "batch_adapter":
|
||||
splitlayer = BatchSplitSequentialLayer()
|
||||
self.insert_sequential_module(ref, delta_module=splitlayer, delta_name="batchsplit_sequential")
|
||||
elif module_type == "batch_lora":
|
||||
splitlayer = BatchSplitParallelLayer()
|
||||
self.insert_parallel_module(ref, delta_module=splitlayer, delta_name="batchsplit_parallel")
|
||||
self.modified_points[key] = splitlayer
|
||||
|
||||
splitlayer = self.modified_points[key]
|
||||
if (module_type == "adapter" and not isinstance(splitlayer, SplitSequentialLayer)) or \
|
||||
(module_type == "lora" and not isinstance(splitlayer, SplitParallelLayer)) or \
|
||||
(module_type == "batch_adapter" and not isinstance(splitlayer, BatchSplitSequentialLayer)) or \
|
||||
(module_type == "batch_lora" and not isinstance(splitlayer, BatchSplitParallelLayer)):
|
||||
raise ValueError("one modified_point can have at most one module_type")
|
||||
|
||||
if module_type.startswith("batch_"):
|
||||
self.batch_layer[module_name].append(splitlayer)
|
||||
|
||||
if not splitlayer.split_attach(module_name, module):
|
||||
raise ValueError("another module with the same name '{}' has been added to {}".format(module_name, key))
|
||||
|
||||
def new_adapter_like(self, module, **kwargs):
|
||||
adapterlayer = AdapterLayer(**kwargs)
|
||||
self.delta_modules.append(adapterlayer)
|
||||
return adapterlayer
|
||||
|
||||
def new_lora_like(self, child_module):
|
||||
def new_lora_like(self, child_module, **kwargs):
|
||||
if isinstance(child_module, nn.Linear):
|
||||
in_features, out_features = child_module.in_features, child_module.out_features
|
||||
new_module = LowRankLinear(in_features = in_features,
|
||||
out_features = out_features,
|
||||
weight = child_module.weight,
|
||||
r=self.lora_r,
|
||||
lora_alpha=self.lora_alpha,
|
||||
lora_dropout=self.lora_dropout)
|
||||
**kwargs,)
|
||||
self.delta_modules.append(new_module)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
26
test.py
26
test.py
|
@ -1,26 +0,0 @@
|
|||
from transformers import BertModel
|
||||
model = BertModel.from_pretrained("bert-base-cased")
|
||||
from opendelta import Visualization
|
||||
Visualization(model).structure_graph()
|
||||
from opendelta import SplitModel
|
||||
delta_model = SplitModel(model)
|
||||
print("attach here")
|
||||
delta_model.attach("output.dense", "adapter_A", "adapter")
|
||||
delta_model.attach("output.dense", "adapter_B", "adapter")
|
||||
delta_model.attach(["2.output.dense", "2.attention.output.dense"], "adapter_C", "adapter")
|
||||
delta_model.update()
|
||||
# delta_model.attach(["attention.self.query", "attention.self.key"], "lora_A", "lora")
|
||||
delta_model.log() # This will visualize the backbone after modification and other information.
|
||||
import torch
|
||||
x = torch.randint(0, 10, (16, 128)).cuda()
|
||||
import time
|
||||
model = model.cuda()
|
||||
st_time = time.time()
|
||||
print("run here")
|
||||
y = model(x)
|
||||
print("detach here")
|
||||
delta_model.detach("3.output.dense", "adapter_A")
|
||||
delta_model.log()
|
||||
print("run here")
|
||||
y = model(x)
|
||||
print(time.time() - st_time)
|
Loading…
Reference in New Issue