2022-02-14 21:19:03 +08:00
|
|
|
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
|
|
|
|
from opendelta.utils.name_based_addressing import *
|
|
|
|
from opendelta.utils.cuda import get_device
|
|
|
|
from opendelta.basemodel import DeltaBase
|
|
|
|
from typing import *
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from opendelta import BaseDeltaConfig
|
|
|
|
from decorator import decorate
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from opendelta import logging
|
|
|
|
# Module-level logger obtained from opendelta's logging utilities, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
class SoftPromptConfig(BaseDeltaConfig):
    r"""
    Configuration class that stores the hyper-parameters of a :py:class:`SoftPromptModel`.

    Args:
        soft_token_num (:obj:`int`, *optional*, defaults to 100): number of soft tokens placed
            in front of the input.
        init_range (:obj:`float`, *optional*, defaults to 0.5): half-width of the uniform
            distribution used when the soft tokens are initialized randomly.
        token_init (:obj:`bool`, *optional*, defaults to :obj:`True`): whether the soft tokens
            are initialized from the backbone's input embeddings instead of randomly.
        **kwargs: forwarded unchanged to :py:class:`BaseDeltaConfig`.
    """

    def __init__(
        self,
        soft_token_num=100,
        init_range=0.5,
        token_init=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Capture this call's bindings once so every constructor argument can be looked up by name.
        call_values = locals()
        for name in get_arg_names_inside_func(self.__init__):
            if hasattr(self, name):
                # Already registered by the parent config; keep the parent's value.
                continue
            setattr(self, name, call_values[name])
|
|
|
|
|
2022-04-14 11:22:41 +08:00
|
|
|
|
2022-02-14 21:19:03 +08:00
|
|
|
|
|
|
|
class SoftPromptLayer(nn.Module):
    r"""This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
    this template does not need any textual template either. Additional tokens are directly
    concatenated into the input embeddings. There are two initializations of the new tokens:
    (1) random initialization, and (2) initialization with the tokens of the plm (we simply take
    the first n_tokens, similar to their implementation).

    Note that the same result can be obtained with :obj:`SoftManualTemplate` by placing
    ``n_token`` <soft> tokens before <text_a>.

    Args:
        soft_token_num (:obj:`int`): number of soft tokens prepended to every sequence; must be positive.
        raw_embedding (:obj:`nn.Module`): the backbone's input embedding module, called on token ids.
            It is stored outside the module registry, so it is not a sub-module of this layer.
        init_range (:obj:`float`): half-width of the uniform distribution used when ``token_init`` is False.
        token_init (:obj:`bool`): if True, initialize the soft tokens with the embeddings of
            token ids ``0 .. soft_token_num - 1``.
        pad_id (:obj:`int`): token id treated as padding when the attention mask must be inferred.
        device: device on which the soft embeddings and the ids fed to ``raw_embedding`` are created.

    Raises:
        ValueError: if ``soft_token_num`` is not positive.
        TypeError: if ``raw_embedding`` is not provided.
    """

    def __init__(self,
                 soft_token_num: int = 100,
                 raw_embedding: Optional[nn.Module] = None,
                 init_range: Optional[float] = 0.5,
                 token_init = False,
                 pad_id = 0,
                 device: Optional[str] = None,
                 ):
        super().__init__()
        # Bypass nn.Module.__setattr__ so the backbone's embedding is not registered as a child
        # of this layer (it would otherwise be duplicated in parameters()/state_dict()).
        self.__dict__['raw_embedding'] = raw_embedding

        self.init_range = init_range
        self.num_tokens = soft_token_num
        self.pad_id = pad_id
        self.token_init = token_init
        self.device = device

        if self.num_tokens <= 0:
            raise ValueError(f"soft_token_num must be a positive integer, got {soft_token_num}")
        if raw_embedding is None:
            raise TypeError("raw_embedding must be the backbone's input embedding module, got None")

        # Probe the embedding once to learn the hidden size. Create the probe id on the
        # embedding's device so a non-CPU embedding does not fail with a device mismatch.
        with torch.no_grad():
            probe = torch.tensor([0], device=self.device)
            hidden_dim = raw_embedding(probe).shape[-1]
        self.instantiate(hidden_dim)

    def pre_forward(self, *args, **kwargs):
        """Rewrite the wrapped module's inputs so that the soft prompt is prepended.

        The first positional argument (or ``kwargs['input_ids']``) is taken as the token ids.
        The ids are embedded (unless ``inputs_embeds`` is supplied), prefixed with the soft
        prompt, and passed on as ``kwargs['inputs_embeds']``. The ids themselves are set to
        None or dropped from ``args``. The attention mask, inferred from ``pad_id`` when
        absent, is extended with ones for the prompt positions.

        Returns:
            The rewritten ``(args, kwargs)`` pair.

        Raises:
            RuntimeError: if the attention mask or the embeddings must be derived from
                ``input_ids`` but none were given.
        """
        # If encoder outputs are already supplied (e.g. generation calling the model again),
        # the inputs are passed through unchanged.
        if kwargs.get('encoder_outputs') is not None:
            return args, kwargs

        if 'input_ids' in kwargs:
            input_ids = kwargs['input_ids']
            # The backbone receives inputs_embeds instead of ids.
            kwargs['input_ids'] = None
        elif len(args) > 0:
            input_ids = args[0]
            args = args[1:]
        else:
            input_ids = None

        if kwargs.get('attention_mask') is None:
            # Infer the attention mask from the padding id.
            if input_ids is None:
                raise RuntimeError("no input ids found")
            kwargs['attention_mask'] = (input_ids != self.pad_id).to(torch.int64)

        if kwargs.get('inputs_embeds') is None:
            # Check explicitly instead of a bare except, so genuine embedding errors
            # (device/dtype/index problems) surface with their real message.
            if input_ids is None:
                raise RuntimeError("neither inputs_embeds nor input_ids is specified.")
            inputs_embeds = self.raw_embedding(input_ids)
        else:
            inputs_embeds = kwargs['inputs_embeds']

        batch_size = inputs_embeds.size(0)
        # Match the token embeddings' dtype so concatenation does not promote (e.g. fp16 -> fp32).
        # expand is a view; torch.cat below makes the single copy.
        soft_embeds = self.soft_embeds.to(dtype=inputs_embeds.dtype).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
        kwargs['inputs_embeds'] = inputs_embeds

        am = kwargs['attention_mask']
        # Pad the mask on the left with ones until it covers the full (prompt + input) length.
        n_prefix = inputs_embeds.shape[-2] - am.shape[-1]
        prefix = torch.ones((*am.shape[:-1], n_prefix), dtype=am.dtype, device=am.device)
        # Assigning through ``.data`` extends the caller's mask tensor in place. Any other holder
        # of the same tensor sees the longer mask (presumably a decoder reusing it for
        # cross-attention -- TODO confirm). Do not replace with a new tensor.
        am.data = torch.cat([prefix, am], dim=-1)

        return args, kwargs

    def instantiate(self, hidden_dim) -> None:
        """
        Create the trainable soft-token embedding matrix ``self.soft_embeds``.

        With ``token_init`` the matrix is a detached copy of the embeddings of token ids
        ``0 .. num_tokens - 1`` (keeping their dtype; ``hidden_dim`` is then implied by the
        embedding). Otherwise it is a float32 ``(num_tokens, hidden_dim)`` matrix drawn from
        ``U(-init_range, init_range)``.
        """
        if self.token_init:
            with torch.no_grad():
                ids = torch.arange(self.num_tokens, device=self.device)
                soft_embeds = self.raw_embedding(ids).clone()
        else:
            soft_embeds = torch.empty(self.num_tokens, hidden_dim, dtype=torch.float32)
            soft_embeds.uniform_(-self.init_range, self.init_range)

        if self.device is not None:
            soft_embeds = soft_embeds.to(self.device)
        # Wrap *after* moving. ``Parameter.to(device)`` returns a plain tensor whenever a copy is
        # made, and that tensor would not be registered as a parameter of this module.
        self.soft_embeds = nn.Parameter(soft_embeds, requires_grad=True)
|
|
|
|
|
|
|
|
|
|
|
|
class SoftPromptModel(DeltaBase):
    r"""
    This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
    this template does not need any textual template either. Additional tokens are directly
    concatenated into the input embeddings. There are two initializations of the new tokens:
    (1) random initialization, and (2) initialization with the tokens of the plm (we simply take
    the first n_tokens, similar to their implementation).

    Note that the same result can be obtained with :obj:`SoftManualTemplate` by placing
    ``n_token`` <soft> tokens before <text_a>.

    Args:
        backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
        soft_token_num (:obj:`int`, *optional*, default to 100): Number of new tokens added in front of the input.
        init_range (:obj:`float`, *optional*, default to 0.5): Half-width of the uniform distribution used
            when the new tokens are initialized randomly (``token_init=False``).
        token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens
            with the embeddings of the first ``soft_token_num`` tokens of the plm.
        modified_modules (:obj:`List[str]`, *optional*): Ignored. The soft prompt layer is always
            inserted at a fixed position (``DeltaBase`` receives ``modified_modules=["root"]``).
        exclude_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): Forwarded to :py:class:`DeltaBase`.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be
            unfrozen together with the soft prompt parameters.
        common_structure (:obj:`bool`, *optional*): Ignored. Name-based addressing with a common
            structure mapping is always disabled (``common_structure=False``) for soft prompts.
        interactive_modify (:obj:`bool` or :obj:`int`, *optional*, default to :obj:`False`): Forwarded to
            :py:class:`DeltaBase`.

    Raises:
        AttributeError: if the backbone has no ``get_input_embeddings`` method.
    """
    config_class = SoftPromptConfig
    delta_type = "soft_prompt"
    default_modified_modules = ["root"] # not used

    def __init__(self,
                 backbone_model: nn.Module,
                 soft_token_num=100,
                 init_range = 0.5,
                 token_init=True,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        # The insertion point is fixed for soft prompts, so the user's modified_modules and
        # common_structure are overridden here.
        DeltaBase.__init__(self,
                           backbone_model = backbone_model,
                           modified_modules = ["root"],
                           exclude_modules = exclude_modules,
                           unfrozen_modules = unfrozen_modules,
                           common_structure = False,
                           interactive_modify = interactive_modify,
                           )

        # Store each constructor argument that DeltaBase did not already set as an attribute.
        # This presumably includes soft_token_num / init_range / token_init, which
        # new_module_like reads back.
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name): # not registered in parent class
                setattr(self, arg_name, locals()[arg_name])

        try:
            raw_embedding = self.backbone_model.get_input_embeddings()
        except AttributeError as err:
            raise AttributeError(f"'{type(self.backbone_model)}' object has no attribute 'get_input_embeddings', please pass "+
                                 "input embeddings into 'self.raw_embedding' in user-specific ways.") from err
        # Written through __dict__ to bypass any nn.Module.__setattr__ registration (DeltaBase
        # holds an nn.ModuleList below, so it presumably is an nn.Module). This keeps the
        # backbone's embedding from becoming a sub-module of the delta model.
        self.__dict__['raw_embedding'] = raw_embedding

        self.delta_modules = nn.ModuleList()
        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def add_all_delta_to_backbone(self,
                                  module: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        """Insert the soft prompt layer into the backbone and mark the delta parameters.

        ``modified_modules`` is accepted for interface compatibility but not used. The layer's
        position is fixed by :meth:`update_module`.

        Returns:
            ``module``, unchanged as an object.
        """
        self.update_module()
        self.mark_as_delta()
        return module

    def update_module(self):
        """Create the soft prompt layer and attach it to the module that consumes the input ids.

        For encoder-decoder backbones (``config.is_encoder_decoder``) the target is the encoder.
        Otherwise it is the whole backbone. Attachment is done by ``insert_sequential_module``,
        which presumably runs the layer's ``pre_forward`` before the target's forward
        (see DeltaBase -- TODO confirm).
        """
        soft_prompt_layer = self.new_module_like(self.raw_embedding)
        target = self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model
        self.insert_sequential_module(target,
                                      delta_module=soft_prompt_layer,
                                      delta_name="soft_prompt_layer")

    def new_module_like(self, module):
        """Build a :class:`SoftPromptLayer` on ``module``'s device and record it in ``self.delta_modules``.

        Args:
            module: the embedding module; only its device is read here. The layer itself draws
                from ``self.raw_embedding``.

        Returns:
            The newly created :class:`SoftPromptLayer`.
        """
        module_device = get_device(module)
        soft_prompt_layer = SoftPromptLayer(
            soft_token_num = self.soft_token_num,
            raw_embedding = self.raw_embedding,
            token_init = self.token_init,
            init_range = self.init_range,
            device = module_device,
        )
        self.delta_modules.append(soft_prompt_layer)
        return soft_prompt_layer
|