# OpenDeltaMirror/opendelta/delta_models/soft_prompt.py

from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
from typing import *
import torch
import torch.nn as nn
from opendelta import BaseDeltaConfig
from opendelta import logging
logger = logging.get_logger(__name__)
class SoftPromptConfig(BaseDeltaConfig):
r"""
This is the configuration class to store the configuration of a :py:class:`SoftPromptModel`.
"""
def __init__(
self,
soft_token_num=100,
init_range = 0.5,
token_init = True,
**kwargs
):
super().__init__(**kwargs)
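        # Record every remaining constructor argument as an attribute of the config object,
        # so it is serialized and restored together with the config.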
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
if not hasattr(self, arg_name): # the arg has not been registered in parent config
setattr(self, arg_name, locals()[arg_name])
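

# A minimal usage sketch (hypothetical values; assumes OpenDelta's AutoDeltaModel.from_config
# behaves as elsewhere in the library):
#
#     config = SoftPromptConfig(soft_token_num=50, token_init=False)
#     delta_model = AutoDeltaModel.from_config(config, backbone_model=backbone)
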
class SoftPromptLayer(nn.Module):
r"""This is the implementation of `The Power of Scale for Parameter-Efficient
Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
This template also does not need any textual template. Addition tokens are directly
2022-04-14 11:22:41 +08:00
concatenated into the input ids. There are two initializations of the new tokens.
(1). random initialization. (2) initialize with the tokens of the plm (We simply take
2022-02-14 21:19:03 +08:00
the first n_tokens similar to their implementation).
Note that this template can be simply achieved by :obj:`SoftManualTemplate`, in which
you set ``n_token`` <soft> tokens template before the <text_a> will give the same result.
"""
def __init__(self,
soft_token_num: int = 100,
raw_embedding: Optional[torch.Tensor] = None,
init_range: Optional[float] = 0.5,
other_expand_ids: Optional[Dict] = {"attention_mask":1, "token_type_ids":0},
token_init = False,
pad_id = 0,
device: Optional[str]=None,
):
super().__init__()
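        # Store the (frozen) backbone embedding via __dict__ so it is not registered
        # as a submodule/parameter of this delta layer.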
self.__dict__['raw_embedding'] = raw_embedding
self.init_range = init_range
self.num_tokens = soft_token_num
self.pad_id = pad_id
self.token_init = token_init
self.device = device
self.other_expand_ids = other_expand_ids
assert self.num_tokens>0
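        # Infer the embedding dimension by embedding a dummy token id (0).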
self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
# self.all_pseudo_tokens = {}
def pre_forward(self, *args, **kwargs):
# if attention_mask is passed as PLM's input, modify it here
if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None:
            # In generation, the input has already been forwarded through the (prompted) encoder, so skip.
return args, kwargs
if 'input_ids' in kwargs:
input_ids = kwargs['input_ids']
kwargs['input_ids'] = None
elif len(args) > 0:
input_ids = args[0]
args = args[1:]
else:
input_ids = None
if 'attention_mask' not in kwargs or kwargs['attention_mask'] is None:
# infer attention mask
if input_ids is None:
raise RuntimeError("no input ids found")
kwargs['attention_mask'] = (input_ids != self.pad_id).to(torch.int64)
        if 'inputs_embeds' not in kwargs or kwargs['inputs_embeds'] is None:
            if input_ids is None:
                raise RuntimeError("neither inputs_embeds nor input_ids is specified.")
            inputs_embeds = self.raw_embedding(input_ids)
        else:
            inputs_embeds = kwargs['inputs_embeds']
batch_size = inputs_embeds.size(0)
soft_embeds = self.soft_embeds.repeat(batch_size, 1, 1)
inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
kwargs['inputs_embeds'] = inputs_embeds
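        # inputs_embeds now has shape (batch_size, num_tokens + seq_len, hidden_dim): the
        # learned soft prompt is prepended to the embedded input tokens. The loop below
        # left-pads the auxiliary inputs listed in other_expand_ids (e.g. attention_mask
        # with 1, token_type_ids with 0) so that they match the expanded sequence length.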
for expand_key in self.other_expand_ids:
if expand_key in kwargs:
real_tokens = kwargs[expand_key]
# if expand_key in self.all_pseudo_tokens:
# pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
# else:
pseudo_tokens_value = self.other_expand_ids[expand_key]
pseudo_tokens = torch.ones(
(*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
dtype = real_tokens.dtype,
device=real_tokens.device) * pseudo_tokens_value
# self.all_pseudo_tokens[expand_key] = pseudo_tokens
real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
return args, kwargs

    def instantiate(self, hidden_dim) -> None:
        """Create the soft prompt embedding, the only trainable parameters of soft prompt tuning.

        The soft tokens are either initialized from the embeddings of the first
        ``num_tokens`` vocabulary ids (``token_init=True``) or drawn uniformly from
        ``[-init_range, init_range]``.
        """
        soft_embeds = torch.FloatTensor(self.num_tokens, hidden_dim)
        if self.token_init:
            soft_embeds.data = torch.clone(self.raw_embedding(torch.tensor(list(range(self.num_tokens)))))
        else:
            soft_embeds = soft_embeds.uniform_(-self.init_range, self.init_range)
        # Move to the target device before wrapping in nn.Parameter: Parameter.to() may return
        # a plain Tensor, which would silently drop the parameter registration.
        self.soft_embeds = nn.Parameter(soft_embeds.to(self.device), requires_grad=True)

class SoftPromptModel(DeltaBase):
r"""
This is the implementation of `The Power of Scale for Parameter-Efficient
Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Similar to :obj:`PrefixTuningTemplate`,
This template also does not need any textual template. Addition tokens are directly
2022-04-14 11:22:41 +08:00
concatenated into the input ids. There are two initializations of the new tokens.
(1). random initialization. (2) initialize with the tokens of the plm (We simply take
2022-02-14 21:19:03 +08:00
the first n_tokens similar to their implementation).
Note that this template can be simply achieved by :obj:`SoftManualTemplate`, in which
you set ``n_token`` <soft> tokens template before the <text_a> will give the same result.
Args:
2022-10-14 23:15:38 +08:00
2022-04-14 11:22:41 +08:00
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
2022-02-14 21:19:03 +08:00
soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input.
2022-03-19 15:04:42 +08:00
init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution.
2022-10-14 23:15:38 +08:00
token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the PLM.
other_expand_ids (:obj:`dict`, *optional*, default to ``{'attention_mask':1, 'token_type_ids':0}``): The name of other tokens and its default value that expand along with the input sequence. For example, when you prepend 100 tokens to the input_ids, the attention_mask should be extended, and the token_type_ids should be extended as well.
modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only the implemented ones).
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen together with the prefix parameters.
2022-03-19 15:04:42 +08:00
common_structure (:obj:`bool`): whether using name-based addressing with a common structure mapping.
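
    A minimal usage sketch (assuming a standard ``transformers`` seq2seq checkpoint such as
    ``t5-base`` and the usual OpenDelta freeze/log helpers)::

        from transformers import AutoModelForSeq2SeqLM
        from opendelta import SoftPromptModel

        backbone = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        delta_model = SoftPromptModel(backbone, soft_token_num=100, token_init=True)
        delta_model.freeze_module(exclude=["deltas"], set_state_dict=True)
        delta_model.log()  # inspect trainable parameters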
"""
config_class = SoftPromptConfig
delta_type = "soft_prompt"
default_modified_modules = ["root"] # not used
_need_pseudo_data = False
def __init__(self,
backbone_model: nn.Module,
soft_token_num=100,
init_range = 0.5,
token_init=True,
other_expand_ids={"attention_mask":1, "token_type_ids":0},
modified_modules: Optional[List[str]] = None,
exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[List[str]] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
):
DeltaBase.__init__(self,
backbone_model = backbone_model,
modified_modules = ["root"],
exclude_modules = exclude_modules,
unfrozen_modules = unfrozen_modules,
common_structure = False,
interactive_modify = interactive_modify,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
if not hasattr(self, arg_name): # not registered in parent class
setattr(self, arg_name, locals()[arg_name])
        try:
            self.__dict__['raw_embedding'] = self.backbone_model.get_input_embeddings()
        except AttributeError:
            raise AttributeError(f"'{type(self.backbone_model)}' object has no attribute 'get_input_embeddings'. Please "
                                 "assign the backbone's input embedding module to 'self.raw_embedding' manually.")
self.delta_modules = nn.ModuleList()
self.add_all_delta_to_backbone(self.backbone_model,
self.modified_modules,
)
def add_all_delta_to_backbone(self,
module: nn.Module,
modified_modules: List[str],
) -> nn.Module:
self.update_module()
self.mark_as_delta()
return module
def update_module(self):
soft_prompt_layer = self.new_module_like(self.raw_embedding)
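        # For encoder-decoder backbones, hook only the encoder so that the soft prompt is
        # prepended to the encoder input rather than to the decoder input.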
        self.insert_sequential_module(
            self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model,
            delta_module=soft_prompt_layer,
            delta_name="soft_prompt_layer",
        )
def new_module_like(self, module):
module_device = get_device(module)
soft_prompt_layer = SoftPromptLayer(
soft_token_num = self.soft_token_num,
raw_embedding = self.raw_embedding,
other_expand_ids = self.other_expand_ids,
token_init = self.token_init,
init_range = self.init_range,
device = module_device,
)
        try:
            # If bmtrain is available, wrap the new layer so it also works with BMTrain backbones.
            import bmtrain as bmt
            soft_prompt_layer = bmt.BMTrainModelWrapper(soft_prompt_layer)
        except Exception:
            pass
self.delta_modules.append(soft_prompt_layer)
return soft_prompt_layer