from functools import partial
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.utils.signature import get_arg_names_inside_func, signature
from typing import Optional, Union
from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention
from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention
from transformers.models.bert.modeling_bert import BertAttention
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.models.bart.modeling_bart import BartAttention
from transformers.models.roberta.modeling_roberta import RobertaAttention
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import torch.nn as nn
import torch
import opendelta.utils.logging as logging

logger = logging.get_logger(__name__)


class PrefixLayerT5(nn.Module):
    r"""A layer of the prefix-tuning module. The layer's pre_forward hook passes (or concatenates) the additional
    ``past_key_value`` into the original attention layer's forward function.
    """
    def __init__(self, prefix_token_num, num_heads, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.num_heads = num_heads
        self.device = device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.past_value_reparam = None
        self.instantiated = True

    def pre_forward(self, *args, **kwargs):
        r"""The args and kwargs are inherited from T5Attention's forward function.
        """
        batch_size = args[0].shape[0]
        seq_len = args[0].shape[-2]
        if not self.instantiated:
            self.hidden_dim = args[0].shape[-1]
            self.instantiate(hidden_dim=self.hidden_dim)
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam

        def expand_batchsize(x):
            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0, 1)
            x = x.unsqueeze(0).expand(batch_size, *x.shape)
            return x

        if 'position_bias' in kwargs and kwargs['position_bias'] is not None:
            if kwargs['position_bias'].shape[-1] != seq_len + self.prefix_token_num:  # then the position_bias should be re-calculated
                kwargs['position_bias'] = None
        if kwargs['past_key_value'] is None:
            kwargs['past_key_value'] = (expand_batchsize(past_key), expand_batchsize(past_value))

        past_key_len = kwargs['past_key_value'][0].shape[-2]

        if 'mask' in kwargs and kwargs['mask'] is not None:
            mask_len = kwargs['mask'].shape[-1]
            if past_key_len + seq_len == mask_len + self.prefix_token_num:
                am = kwargs['mask']  # Should check the format of the attention_mask when moving to a new plm.
                kwargs['mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
        return args, kwargs

    def post_forward(self, output):
        r"""Remove the cached position bias, since the next layer may not have prefix tokens.
        """
        output = output[:2] + (None, ) + output[3:]
        return output


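# A minimal, illustrative sketch (not part of the library API; all sizes are made up):
# it shows the shape contract of ``expand_batchsize`` used by the prefix layers in this
# file -- a flat (prefix_token_num, hidden_dim) parameter is split into heads and
# broadcast over the batch so that it can be fed as a ``past_key_value`` entry.
def _example_expand_batchsize_shapes():
    prefix_token_num, num_heads, hidden_dim, batch_size = 6, 8, 512, 2
    x = torch.randn(prefix_token_num, hidden_dim)
    # (prefix_token_num, hidden_dim) -> (num_heads, prefix_token_num, head_dim)
    x = x.reshape(prefix_token_num, num_heads, -1).transpose(0, 1)
    # -> (batch_size, num_heads, prefix_token_num, head_dim), broadcast over the batch
    x = x.unsqueeze(0).expand(batch_size, *x.shape)
    assert x.shape == (batch_size, num_heads, prefix_token_num, hidden_dim // num_heads)
    return x

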
class PrefixLayerBart(nn.Module):
    r"""A layer of the prefix-tuning module. The layer's pre_forward hook passes (or concatenates) the additional
    ``past_key_value`` into the original attention layer's forward function.
    """
    def __init__(self, prefix_token_num, num_heads, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.num_heads = num_heads
        self.device = device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.past_value_reparam = None
        self.instantiated = True

    def pre_forward(self, *args, **kwargs):
        r"""The args and kwargs are inherited from BartAttention's forward function.
        """
        batch_size = kwargs['hidden_states'].shape[0]
        if not self.instantiated:
            self.hidden_dim = kwargs['hidden_states'].shape[-1]
            self.instantiate(hidden_dim=self.hidden_dim)
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam

        def expand_batchsize(x):
            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0, 1)
            x = x.unsqueeze(0).expand(batch_size, *x.shape)
            return x

        if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
            kwargs['past_key_value'] = (expand_batchsize(past_key), expand_batchsize(past_value))

        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
            am = kwargs['attention_mask']  # Should check the format of the attention_mask when moving to a new plm.
            kwargs['attention_mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
        return args, kwargs


class PrefixLayerGPT2(nn.Module):
    r"""A layer of the prefix-tuning module. The layer's pre_forward hook passes (or concatenates) the additional
    ``past_key_value`` into the original attention layer's forward function.
    """
    def __init__(self, prefix_token_num, num_heads, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.num_heads = num_heads
        self.device = device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.past_value_reparam = None
        self.instantiated = True

    def pre_forward(self, *args, **kwargs):
        r"""The args and kwargs are inherited from GPT2Attention's forward function.
        """
        batch_size = args[0].shape[0]
        if not self.instantiated:
            self.hidden_dim = args[0].shape[-1]
            self.instantiate(hidden_dim=self.hidden_dim)
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam

        def expand_batchsize(x):
            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0, 1)
            x = x.unsqueeze(0).expand(batch_size, *x.shape)
            return x

        if kwargs['layer_past'] is None:
            kwargs['layer_past'] = (expand_batchsize(past_key), expand_batchsize(past_value))
        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
            am = kwargs['attention_mask']  # Should check the format of the attention_mask when moving to a new plm.
            kwargs['attention_mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
        return args, kwargs


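# A minimal, illustrative sketch (not part of the library API; shapes and values are made
# up): the prefix layers above extend an *additive* attention mask (0.0 for visible
# positions, a large negative value for masked ones) by prepending zeros, so that every
# query is also allowed to attend to the prepended prefix keys/values.
def _example_extend_additive_mask():
    prefix_token_num = 6
    # (batch, 1, 1, seq_len) additive mask with the last position masked out
    am = torch.zeros(2, 1, 1, 10)
    am[..., -1] = -1e4
    extended = torch.cat(
        [torch.zeros((*am.shape[:-1], prefix_token_num), dtype=am.dtype, device=am.device), am],
        dim=-1)
    assert extended.shape == (2, 1, 1, prefix_token_num + 10)
    return extended

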
class PrefixLayerDistilBert(nn.Module):
    # TODO: Warning: this implementation still has bugs.
    def __init__(self, prefix_token_num, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.device = device
        self.key_instantiated = False
        self.value_instantiated = False

    def forward(self, *args, **kwargs):
        mask = kwargs["mask"]
        key, value = kwargs['key'], kwargs['value']
        prefix_mask = torch.ones(mask.shape[0], self.prefix_token_num, dtype=mask.dtype, device=mask.device)
        concated_mask = torch.cat([prefix_mask, mask], dim=1)
        pseudo_prefix = torch.zeros(key.shape[0], self.prefix_token_num, key.shape[2], dtype=key.dtype, device=key.device)
        concated_key = torch.cat([pseudo_prefix, key], dim=1)
        concated_value = torch.cat([pseudo_prefix, value], dim=1)
        kwargs["mask"] = concated_mask
        kwargs['key'] = concated_key
        kwargs['value'] = concated_value
        return args, kwargs

    def key_instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.key_instantiated = True

    def value_instantiate(self, hidden_dim):
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value_reparam = None
        self.value_instantiated = True

    def key_pre_forward(self, *args, **kwargs):
        _input = args[0]
        _input = _input[:, self.prefix_token_num:, :]
        args = (_input,) + args[1:]
        return args, kwargs

    def value_pre_forward(self, *args, **kwargs):
        _input = args[0]
        _input = _input[:, self.prefix_token_num:, :]
        args = (_input,) + args[1:]
        return args, kwargs

    def key_forward(self, output: torch.Tensor):  # Check whether running prefix is OK (12.21).
        if isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError
        if not self.key_instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got key hidden_dim: {self.hidden_dim}")
            self.key_instantiate(hidden_dim=self.hidden_dim)
        batch_size = hiddens.shape[0]
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        output = torch.cat([past_key.unsqueeze(0).expand(batch_size, *past_key.shape), hiddens], dim=1)
        return output

    def value_forward(self, output: torch.Tensor):  # Check whether running prefix is OK (12.21).
        if isinstance(output, torch.Tensor):
            hiddens = output
        else:
            raise TypeError
        if not self.value_instantiated:
            self.hidden_dim = hiddens.shape[-1]
            logger.debug(f"Got value hidden_dim: {self.hidden_dim}")
            self.value_instantiate(hidden_dim=self.hidden_dim)
        batch_size = hiddens.shape[0]
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam
        output = torch.cat([past_value.unsqueeze(0).expand(batch_size, *past_value.shape), hiddens], dim=1)
        return output


class PrefixLayerBert(nn.Module):
    r"""A layer of the prefix-tuning module. The layer's pre_forward hook passes (or concatenates) the additional
    ``past_key_value`` into the original attention layer's forward function.
    """
    def __init__(self, prefix_token_num, num_heads, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.num_heads = num_heads
        self.device = device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.past_value_reparam = None
        self.instantiated = True

    def pre_forward(self, *args, **kwargs):
        r"""The args and kwargs are inherited from BertAttention's forward function.
        """
        batch_size = args[0].shape[0]
        if not self.instantiated:
            self.hidden_dim = args[0].shape[-1]
            self.instantiate(hidden_dim=self.hidden_dim)
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam

        def expand_batchsize(x):
            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0, 1)
            x = x.unsqueeze(0).expand(batch_size, *x.shape)
            return x

        if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
            kwargs['past_key_value'] = (expand_batchsize(past_key), expand_batchsize(past_value))

        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
            am = kwargs['attention_mask']  # Should check the format of the attention_mask when moving to a new plm.
            kwargs['attention_mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
        elif len(args) > 1:  # the attention mask was passed as a positional argument
            am = args[1]
            am = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
            args = (args[0], am) + args[2:]
        return args, kwargs


class PrefixLayerRoberta(nn.Module):
    r"""A layer of the prefix-tuning module. The layer's pre_forward hook passes (or concatenates) the additional
    ``past_key_value`` into the original attention layer's forward function.
    """
    def __init__(self, prefix_token_num, num_heads, device,):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.num_heads = num_heads
        self.device = device
        self.instantiated = False

    def instantiate(self, hidden_dim):
        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
        self.past_key_reparam = None
        self.past_value_reparam = None
        self.instantiated = True

    def pre_forward(self, *args, **kwargs):
        r"""The args and kwargs are inherited from RobertaAttention's forward function.
        """
        batch_size = args[0].shape[0]
        if not self.instantiated:
            self.hidden_dim = args[0].shape[-1]
            self.instantiate(hidden_dim=self.hidden_dim)
        if self.past_key_reparam is None:
            past_key = self.past_key.data
        else:
            past_key = self.past_key_reparam
        if self.past_value_reparam is None:
            past_value = self.past_value.data
        else:
            past_value = self.past_value_reparam

        def expand_batchsize(x):
            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0, 1)
            x = x.unsqueeze(0).expand(batch_size, *x.shape)
            return x

        if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
            kwargs['past_key_value'] = (expand_batchsize(past_key), expand_batchsize(past_value))

        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
            am = kwargs['attention_mask']  # Should check the format of the attention_mask when moving to a new plm.
            kwargs['attention_mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
        elif len(args) > 1:  # the attention mask was passed as a positional argument
            am = args[1]
            am = torch.cat([torch.zeros((*am.shape[:-1], self.prefix_token_num), dtype=am.dtype, device=am.device), am], dim=-1)
            args = (args[0], am) + args[2:]
        return args, kwargs

    # def post_forward(self, output):
    #     r"""Remove the cached position bias, since the next layer may not have prefix tokens.
    #     """
    #     output = output[:2] + (None, ) + output[3:]
    #     return output


class ReparameterizeFunction(nn.Module):
    r"""Prefix tuning's performance is better with a reparameterization module, which generates
    the ``past_key_value`` using an MLP instead of directly optimizing the ``past_key_value`` as a leaf variable.

    In our implementation, the reparameterization module is constructed according to the total number of parameters
    in all ``past_key_value``s. Thus, a variable number of prefix layers is supported (it is not restricted to being
    equal to the number of layers of the pretrained language model).
    """
    def __init__(self, prefix_token_num, embed_dim, dropout_rate=0.0, mid_dim=512, module_list=[]):
        super().__init__()
        self.prefix_token_num = prefix_token_num
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        self.module_list = module_list
        self.dropout = nn.Dropout(dropout_rate)
        self.record_parameters()
        self.compatibility_check()
        self.define_reparameterization_network()

    def record_parameters(self):
        r"""Enumerate the parameters that need to be reparameterized.
        Then, delete the original parameters.
        """
        tot = 0
        for module in self.module_list:
            for n, parameters in module.named_parameters():
                tot += parameters.numel()
                module.register_parameter(n, None)
        self.total_parameters_num = tot

    def compatibility_check(self,):
        r"""May be removed.
        """
        assert self.total_parameters_num % self.prefix_token_num == 0

    def allocate_parameter(self):
        r"""At the beginning of each forward pass through the whole network (PLM),
        calculate the reparameterized past_key and past_value (``past_key_reparam`` and ``past_value_reparam``)
        for later use in each layer.
        """
        input_tokens = self.input_tokens
        temp_control = self.wte(input_tokens)
        past_key_values = self.control_trans(temp_control)
        seqlen, _ = past_key_values.shape

        past_key_values = past_key_values.view(seqlen, len(self.module_list) * 2, self.module_list[0].hidden_dim)
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([1, 0, 2]).split(2)

        for module_id, module in enumerate(self.module_list):
            module.past_key_reparam = past_key_values[module_id][0]
            module.past_value_reparam = past_key_values[module_id][1]

    def pre_forward(self, *args, **kwargs):
        r"""First forward through the reparameterization network, then go through the normal forward pass of the PLM.
        """
        self.allocate_parameter()
        return args, kwargs

    def define_reparameterization_network(self) -> None:
        r"""Build the reparameterization module.
        """
        self.input_tokens = nn.Parameter(torch.arange(self.prefix_token_num).long(), requires_grad=False)  # to allow automatic devicing
        self.wte = nn.Embedding(self.prefix_token_num, self.embed_dim)
        self.control_trans = nn.Sequential(
            nn.Linear(self.embed_dim, self.mid_dim),
            nn.Tanh(),
            nn.Linear(self.mid_dim, self.total_parameters_num // self.prefix_token_num)
        )


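# A minimal, illustrative walk-through (not part of the library API; all sizes are made
# up) of the tensor bookkeeping in ``ReparameterizeFunction.allocate_parameter``: the MLP
# emits, for each prefix token, one slice of every layer's past_key and past_value, and
# the result is reshaped and split into per-layer (key, value) chunks.
def _example_reparameterization_shapes():
    prefix_token_num, hidden_dim, num_layers = 6, 512, 3
    # MLP output: one row per prefix token, width = 2 * num_layers * hidden_dim
    mlp_out = torch.randn(prefix_token_num, num_layers * 2 * hidden_dim)
    pkv = mlp_out.view(prefix_token_num, num_layers * 2, hidden_dim)
    pkv = pkv.permute([1, 0, 2]).split(2)  # one (2, prefix_token_num, hidden_dim) chunk per layer
    assert len(pkv) == num_layers
    for chunk in pkv:
        past_key_reparam, past_value_reparam = chunk[0], chunk[1]
        assert past_key_reparam.shape == (prefix_token_num, hidden_dim)
        assert past_value_reparam.shape == (prefix_token_num, hidden_dim)
    return pkv

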
class PrefixConfig(BaseDeltaConfig):
    def __init__(
        self,
        prefix_token_num=6,
        reparameterize=True,
        embed_dim: Optional[int]=512,
        mid_dim: Optional[int]=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # the arg has not been registered in the parent config
                setattr(self, arg_name, locals()[arg_name])


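# A hedged usage sketch (not part of the library API): a PrefixConfig only records the
# hyper-parameters above; attaching them to a backbone is done separately through
# PrefixModel (or the generic DeltaBase/AutoDeltaModel config interfaces in OpenDelta).
def _example_prefix_config():
    config = PrefixConfig(prefix_token_num=16, reparameterize=True, embed_dim=512, mid_dim=512)
    return config

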
class PrefixModel(DeltaBase):
    r"""The implementation of `Prefix-Tuning: Optimizing Continuous Prompts for Generation <https://arxiv.org/abs/2101.00190>`_ .
    However, as the attention blocks of different PLMs differ substantially, e.g., in the input arguments and in the naming
    convention of ``past_key_value``, we have to implement a different prefix layer for each PLM. Given this inconvenience
    at the code level, we only support several commonly used backbone models (currently: T5, DistilBERT, BERT, RoBERTa,
    GPT-2, BART). If you are trying to apply delta tuning to other backbone models, we suggest trying other delta models
    or implementing it and making a pull request.

    Experimental Feature:

        Support inserting prefix tokens before only a subset of the layers, e.g., layers 3, 4, 6, and 10, leaving the
        other layers untouched.

    .. note::

        If reparameterization is used, the trainable parameters live in the reparameterization network, not in the
        prefix itself; the reparameterization network is attached to the first prefix layer. We will add a function to
        save only the generated prefix parameters in the next version.

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        prefix_token_num (:obj:`int`): the number of prefix tokens.
        reparameterize (:obj:`bool`): Whether to use reparameterization for prefix tuning.
        embed_dim (:obj:`int`): The embedding dimension of the prefix tokens when using reparameterization.
        mid_dim (:obj:`int`): The hidden dimension of the reparameterization network.
        modified_modules (:obj:`List[str]`): For prefix tuning, it must refer to an attention layer (currently, only
                         the implemented ones).
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
                         together with the prefix parameters.
        common_structure (:obj:`bool`): whether to use name-based addressing with a common structure mapping.

    """
    config_class = PrefixConfig
    delta_type = "prefix"
    default_modified_modules = ['attn']
    def __init__(self,
                 backbone_model: nn.Module,
                 prefix_token_num=6,
                 reparameterize=True,
                 embed_dim: Optional[int]=512,
                 mid_dim: Optional[int]=512,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in the parent class
                setattr(self, arg_name, locals()[arg_name])

        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def add_all_delta_to_backbone(self,
                                  module: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        first_modified_module = None
        # Currently, we assume the layers appear in order in named_modules.
        # Thus the first modified module is the first module that the tensor flows to.
        for key, _ in module.named_modules():
            if self.find_key(key, modified_modules):
                logger.debug("find key {}".format(key))
                if first_modified_module is None:
                    _, _, ref = self.find_module(module, key)
                    first_modified_module = ref
                self.update_module(module, key)

        self._pseudo_data_to_instantiate(module)

        if self.reparameterize:
            reparams = ReparameterizeFunction(prefix_token_num=self.prefix_token_num,
                                              embed_dim=self.embed_dim,
                                              mid_dim=self.mid_dim,
                                              module_list=self.delta_modules)
            self.delta_modules = None
            self.reparams = reparams
            self.insert_sequential_module(first_modified_module, delta_module=reparams, delta_name="reparams", strict=False)
        self.mark_as_delta()
        return module

    def update_module(self, module: nn.Module, key: str):
        _, _, ref = self.find_module(module, key)

        prefixlayer, ref = self.new_module_like(ref)
        self.insert_sequential_module(ref, delta_module=prefixlayer, delta_name="prefix")
        self.delta_modules.append(prefixlayer)

    def new_module_like(self, module):
        # TODO: support more attention modules

        if isinstance(module, T5Attention) or isinstance(module, T5LayerSelfAttention):
            if isinstance(module, T5LayerSelfAttention):
                module = module.SelfAttention  # inner module
            module_device = get_device(module)
            prefixlayer = PrefixLayerT5(prefix_token_num=self.prefix_token_num, num_heads=module.n_heads, device=module_device)
        elif isinstance(module, MultiHeadSelfAttention):  # MultiHeadSelfAttention doesn't expose past_key_value in its forward interface.
            module_device = get_device(module)
            prefixlayer = PrefixLayerDistilBert(prefix_token_num=self.prefix_token_num, device=module_device)
            self.insert_sequential_module(getattr(module, "k_lin"), pre_caller=prefixlayer.key_pre_forward, post_caller=prefixlayer.key_forward)
            self.insert_sequential_module(getattr(module, "v_lin"), pre_caller=prefixlayer.value_pre_forward, post_caller=prefixlayer.value_forward)
        elif isinstance(module, BertAttention):
            module_device = get_device(module)
            prefixlayer = PrefixLayerBert(prefix_token_num=self.prefix_token_num, num_heads=module.self.num_attention_heads, device=module_device)
        elif isinstance(module, RobertaAttention):
            module_device = get_device(module)
            prefixlayer = PrefixLayerRoberta(prefix_token_num=self.prefix_token_num, num_heads=module.self.num_attention_heads, device=module_device)
        elif isinstance(module, GPT2Attention):
            module_device = get_device(module)
            prefixlayer = PrefixLayerGPT2(prefix_token_num=self.prefix_token_num, num_heads=module.num_heads, device=module_device)
        elif isinstance(module, BartAttention):
            module_device = get_device(module)
            prefixlayer = PrefixLayerBart(prefix_token_num=self.prefix_token_num, num_heads=module.num_heads, device=module_device)
        else:
            raise NotImplementedError(type(module))
        return prefixlayer, module
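

# A hedged end-to-end usage sketch (not executed on import): attaching a PrefixModel to a
# T5 backbone. The checkpoint name is illustrative, and ``freeze_module`` follows the usual
# DeltaBase interface as we understand it; adjust to your OpenDelta version.
def _example_prefix_model_usage():
    backbone = T5ForConditionalGeneration.from_pretrained("t5-small")  # any T5 checkpoint
    delta_model = PrefixModel(backbone_model=backbone, prefix_token_num=6, reparameterize=True)
    # Freeze the backbone so that only the prefix / reparameterization parameters are trained.
    delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
    return backbone, delta_model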