Merge branch 'with_bmtrain' of github.com:thunlp/OpenDelta into with_bmtrain

This commit is contained in:
shengdinghu 2022-10-17 09:11:55 +00:00
commit 2d3fc201d2
12 changed files with 135 additions and 155 deletions

View File

@@ -0,0 +1,50 @@
import bmtrain as bmt
import opendelta as od
from opendelta import LoraModel, AdapterModel, CompacterModel, LowRankAdapterModel, BitFitModel
import torch
import numpy
import random

def manual_seed(seed):
    torch.manual_seed(seed)
    numpy.random.seed(seed)
    random.seed(seed)
from model_center.model import Bert, BertConfig
bmt.init_distributed()
config = BertConfig.from_pretrained("/yinxr/zwl/.cache/model_center/bert-base-uncased")
config.dropout_p = 0
model = Bert.from_pretrained("/yinxr/zwl/.cache/model_center/bert-base-uncased", config)
print("before modify")
od.Visualization(model).structure_graph()
manual_seed(233)
delta_model = LoraModel(backbone_model=model, modified_modules=['project_q', 'project_k'])
# delta_model = AdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = CompacterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = LowRankAdapterModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn'])
# delta_model = BitFitModel(backbone_model=model, modified_modules=['[r]layers\\.(\d)+\\.self_att', '[r]layers\\.(\d)+\\.ffn', '[r](.*)layernorm(.*)'])
print(delta_model.delta_modules)
print("after modify")
delta_model.log()
# This will visualize the backbone after modification, along with other information.
delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
print("after freeze")
delta_model.log()
# Setting set_state_dict=True tells the method to change the state_dict of the backbone_model so that it maintains only the trainable parts.
manual_seed(233)
inp = torch.randint(0, 30000, (32, 128)).cuda()
length = torch.randint(0, 128, (32,)).cuda()
attention_mask = (torch.arange(inp.shape[1], device=inp.device)[None, :].repeat(inp.shape[0], 1) < length[:, None])
out = model(inp, attention_mask=attention_mask, output_logits=True).logits
print(out)
if bmt.rank() == 0:
    torch.save(model.state_dict(), "test.pt")
ckpt = torch.load("test.pt")
print(ckpt.keys())
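A minimal follow-up sketch, not part of this commit: because freeze_module was called with set_state_dict=True, test.pt should contain only the trainable (delta and unfrozen) entries, so it can be reloaded into an identically modified backbone with strict=False. The variable names reuse the ones from the script; the check itself is my illustration.

# Hypothetical continuation for illustration only (assumes the backbone was rebuilt
# and modified exactly as above before loading):
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("unexpected keys:", unexpected)  # should be empty if the delta setup matches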

View File

@@ -0,0 +1 @@
python3 -m torch.distributed.launch --master_addr localhost --master_port 34123 --nproc_per_node $1 --nnodes 1 --node_rank 0 2_with_bmtrain.py
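The launcher's single positional argument ($1) is the number of processes per node, i.e. the GPU count. Assuming the one-liner above is saved as a shell script (its file name is not shown in this view, so `launch.sh` is only a placeholder), a single-node four-GPU run would look like `bash launch.sh 4`.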

View File

@@ -85,6 +85,11 @@ class AdapterLayer(nn.Module, InterFaceMixin):
        self.instantiated = True
        # initialize the weight, which is important for fast convergence and better performance.
        self.apply(self._init_weight)
        try:
            import bmtrain as bmt
            self.modulelist = bmt.BMTrainModelWrapper(self.modulelist)
        except:
            pass

    def _init_weight(self, module):
        if isinstance(module, nn.Linear):
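This hunk introduces the optional bmtrain wrapping that recurs throughout the commit: when bmtrain can be imported, the freshly instantiated delta sub-modules are handed to bmt.BMTrainModelWrapper; otherwise the plain torch modules are kept, so OpenDelta still runs without bmtrain installed. A minimal sketch of the shared pattern, where the helper name is mine and not part of the diff (each call site in the commit inlines the try/except instead):

def _maybe_wrap_with_bmtrain(module):
    # Wrap the module so its parameters are managed by bmtrain's distributed
    # training when the library is available; otherwise return it unchanged.
    try:
        import bmtrain as bmt
        return bmt.BMTrainModelWrapper(module)
    except ImportError:
        return module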

View File

@@ -2,6 +2,7 @@ from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
from opendelta.utils.cuda import get_device, get_dtype
import torch.nn as nn
import torch
@@ -28,17 +29,24 @@ class BitFitConfig(BaseDeltaConfig):
                setattr(self, arg_name, locals()[arg_name])

class BiasLayer(nn.Module):
    def __init__(self, init_method="zero"):
    def __init__(self, init_method="zero", dtype=None, device=None):
        super().__init__()
        self.init_method=init_method
        self.instantiated = False
        self.dtype = dtype
        self.device = device

    def instantiate(self, hidden_dim):
        if self.init_method == "zero":
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype, device=self.device))
        else:
            raise NotImplementedError
        self.instantiated = True
        try:
            import bmtrain as bmt
            self.bias = bmt.BMTrainModelWrapper(self.bias)
        except:
            pass

    def post_forward(self, output):
        r"""Presuming the first argument is the tensor to add bias along the last dimension.
@@ -145,7 +153,7 @@ class BitFitModel(DeltaBase):
                      ):
        if is_leaf_module(module):
            # if it is a leaf module, add bias to it regardless of its type.
            if isinstance(module, nn.Linear):
            if self.check_linear(module):
                self.add_bias_to_linear(module)
            else:
                # for example, layer_norms, lm_heads.
@@ -153,34 +161,43 @@ class BitFitModel(DeltaBase):
        else:
            # for the non-leaf modules, by default it will add bias only to the linear submodules.
            for n, c in module.named_modules():
                if isinstance(c, nn.Linear):
                    if c.bias is None:
                        bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
                        c.register_parameter('bias', bias)
                        self._reset_bias_parameters(c)
                        self.delta_params.append(bias)
                    else:
                        c.bias.requires_grad = True
                        self.delta_params.append(c.bias)
                if self.check_linear(c):
                    self.add_bias_to_linear(c)
                else:
                    pass

    def add_bias_to_linear(self, c):
        if c.bias is None:
            bias = nn.Parameter(torch.empty(c.out_features), requires_grad=True)
            c.register_parameter('bias', bias)
            self._reset_bias_parameters(c)
            try:
                import bmtrain as bmt
                bias = bmt.BMTrainModelWrapper(bias)
            except:
                pass
            c.register_parameter('bias', bias)
            self.delta_params.append(bias)
        else:
            c.bias.requires_grad = True
            self.delta_params.append(c.bias)

    def add_bias_to_others(self, c):
        new_bias = BiasLayer()
        new_bias = BiasLayer(dtype=get_dtype(c), device=get_device(c))
        self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since
        # the name `bias` is reserved for some modules such as RoBERTa's LayerNorm.
        self.delta_modules.append(new_bias)

    def check_linear(self, m):
        if isinstance(m, nn.Linear):
            return True
        else:
            try:
                from model_center.layer import Linear
                if isinstance(m, Linear):
                    return True
            except:
                pass
            return False

    @staticmethod

View File

@@ -93,6 +93,13 @@ class HyperComplexAdapterLayer(nn.Module):
                                    phm_init_range=self.phm_init_range,
                                    kronecker_prod=self.kronecker_prod).to(self.device)
        self.instantiated = True
        try:
            import bmtrain as bmt
            self.activation = bmt.BMTrainModelWrapper(self.activation)
            self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
            self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
        except:
            pass

    def post_forward(self, output):

View File

@@ -62,7 +62,7 @@ def matvec_product(W: torch.Tensor, x: torch.Tensor,
    else:
        H = kronecker_product_einsum_batched(phm_rule, W).sum(0)
    y = torch.matmul(input=x, other=H)
    y = torch.matmul(input=x.to(H.dtype), other=H).to(x.dtype)
    if bias is not None:
        y += bias
    return y
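My reading of the cast above (the diff itself gives no rationale): under bmtrain/model_center the backbone weights are typically kept in half precision, so H can be float16 while x arrives in float32, and torch.matmul requires matching dtypes. A tiny standalone illustration:

import torch
x = torch.randn(4, 8)                        # e.g. float32 activations
H = torch.randn(8, 8, dtype=torch.float16)   # e.g. float16 weights under fp16 training
# torch.matmul(x, H) raises a dtype-mismatch RuntimeError; casting in and casting
# back, as the patched line does, keeps x's original dtype on the output:
y = torch.matmul(x.to(H.dtype), H).to(x.dtype)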

View File

@@ -33,7 +33,7 @@ class LowRankLinear(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        W = self.W_left*self.W_right
        output = torch.matmul(input=x, other=W)
        output = torch.matmul(input=x.to(W.dtype), other=W).to(x.dtype)
        if self.bias:
            output += self.b
        return output

View File

@@ -137,15 +137,17 @@ class LoraModel(DeltaBase):
        pass

    def new_module_like(self, child_module):
        if isinstance(child_module, nn.Linear):
            in_features, out_features = child_module.in_features, child_module.out_features
            new_module = LowRankLinear(in_features = in_features,
                                       out_features = out_features,
                                       weight = child_module.weight,
                                       r=self.lora_r,
                                       lora_alpha=self.lora_alpha,
                                       lora_dropout=self.lora_dropout)
            self.delta_modules.append(new_module)
        else:
            raise NotImplementedError
        in_features, out_features = child_module.in_features, child_module.out_features
        new_module = LowRankLinear(in_features = in_features,
                                   out_features = out_features,
                                   weight = child_module.weight,
                                   r=self.lora_r,
                                   lora_alpha=self.lora_alpha,
                                   lora_dropout=self.lora_dropout)
        try:
            import bmtrain as bmt
            new_module = bmt.BMTrainModelWrapper(new_module)
        except:
            pass
        self.delta_modules.append(new_module)
        return new_module

View File

@@ -1,127 +0,0 @@
from typing import Optional, Union
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn
from opendelta import BaseDeltaConfig

class LoraConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~LoraModel`
    """
    def __init__(
        self,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # the arg has not been registered in parent config
                setattr(self, arg_name, locals()[arg_name])
class LoraModel(DeltaBase):
    r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
    Thanks to their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
    to replace the linear layer of the backbone model.

    class attributes:
        - default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify the q and v matrices in the
          attention layer. However, other linear layers can also be modified, and may lead to better performance.

        .. note::
            modified_modules should point to a linear layer. We currently don't support broadcasting to all linear layers in
            a module's child modules.

        - delta_type = "lora"

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
        lora_r (:obj:`int`, *optional*): The rank of the lora parameters. The smaller lora_r is, the fewer parameters lora has.
        lora_alpha (:obj:`int`, *optional*): A hyper-parameter to control the init scale of loralib.linear.
        lora_dropout (:obj:`float`, *optional*): The dropout rate in lora.linear.
        modified_modules (:obj:`List[str]`): The modules to be modified; they must refer to linear layers (currently, only
            the implemented ones).
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the lora parameters.
        common_structure (:obj:`bool`): Whether to use name-based addressing with a common structure mapping.
    """
    config_class = LoraConfig
    delta_type = "lora"
    default_modified_modules = ['attn.q', 'attn.v']
    def __init__(self,
                 backbone_model: nn.Module,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.0,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])
        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )
    def update_module(self, module: nn.Module, key: str):
        parent_ref, child_name, child_ref = self.find_module(module, key)
        new_module = self.new_module_like(child_module=child_ref)
        self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")

    def _pseudo_data_to_instantiate(self, module):
        # no need to pass pseudo input, so overwrite it
        pass

    def new_module_like(self, child_module):
        if isinstance(child_module, nn.Linear):
            in_features, out_features = child_module.in_features, child_module.out_features
            new_module = lora.Linear(in_features=in_features,
                                     out_features=out_features,
                                     r=self.lora_r,
                                     lora_alpha=self.lora_alpha,
                                     lora_dropout=self.lora_dropout)
            new_module.weight = child_module.weight
            new_module.bias = child_module.bias # if bias is None, also copy
        else:
            raise NotImplementedError
        return new_module

    def mark_as_delta(self, module: nn.Module = None):
        if module is None:
            module=self
        for n, p in module.named_parameters():
            param_name = n.split(".")[-1]
            if "lora_A" in param_name or "lora_B" in param_name: # only lora_A and lora_B are the delta parameters.
                setattr(p, "_is_delta", True)
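For context: this removed implementation replaced the backbone's linear layers with loralib's lora.Linear, whereas the LoraModel kept in this commit (see the hunk above) builds OpenDelta's own LowRankLinear around the existing weight and, when available, wraps it with bmtrain.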

View File

@@ -69,6 +69,13 @@ class LowRankAdapter(nn.Module):
                                        rank=self.low_rank_rank).to(self.device)
        self.instantiated = True
        try:
            import bmtrain as bmt
            self.activation = bmt.BMTrainModelWrapper(self.activation)
            self.down_sampler = bmt.BMTrainModelWrapper(self.down_sampler)
            self.up_sampler = bmt.BMTrainModelWrapper(self.up_sampler)
        except:
            pass

    def post_forward(self, output):
        r""" Get the hidden_states from the PLM's layer output, pass it into the low-rank adapter,

@@ -88,7 +95,6 @@ class LowRankAdapter(nn.Module):
            logger.debug(f"Got hidden dim hidden_dim {self.hidden_dim}")
            self.instantiate(hidden_dim=self.hidden_dim)
        z = self.down_sampler(hiddens)
        z = self.activation(z)
        adapter_output = self.up_sampler(z)

View File

@@ -225,5 +225,10 @@ class SoftPromptModel(DeltaBase):
            init_range = self.init_range,
            device = module_device,
        )
        try:
            import bmtrain as bmt
            soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
        except:
            pass
        self.delta_modules.append(soft_prompt_layer)
        return soft_prompt_layer

View File

@@ -17,6 +17,20 @@ def get_device(module : Union[nn.Module, nn.Parameter]):
        else:
            raise RuntimeError("The module is paralleled acrossed device, please get device in a inner module")

def get_dtype(module : Union[nn.Module, nn.Parameter]):
    if not (isinstance(module, nn.Module) \
            or isinstance(module, nn.Parameter)):
        raise RuntimeError("module is not an instance of torch.nn.Module or torch.nn.Parameter")
    if hasattr(module, 'dtype'):
        return module.dtype
    else:
        params_dtypes = [p.dtype for p in module.parameters()]
        if len(params_dtypes) == 0:
            return None
        elif len(set(params_dtypes))==1:
            return params_dtypes[0]
        else:
            raise RuntimeError("The module has multiple dtypes, please get the dtype in an inner module")

def move_dict_to_cuda(dict_of_tensor, device):
    for key in dict_of_tensor: