# OpenDeltaMirror/opendelta/delta_models/soft_prompt.py

from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
from typing import *
import torch
import torch.nn as nn
from opendelta import BaseDeltaConfig
from opendelta import logging
logger = logging.get_logger(__name__)
class SoftPromptConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`SoftPromptModel`.
    """
def __init__(
self,
soft_token_num=100,
init_range = 0.5,
token_init = True,
**kwargs
):
super().__init__(**kwargs)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
if not hasattr(self, arg_name): # the arg has not been registered in parent config
setattr(self, arg_name, locals()[arg_name])
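
# A minimal config-driven construction sketch (assumed usage: it relies on
# OpenDelta's ``AutoDeltaModel.from_config`` API, which lives outside this file,
# and ``backbone`` is a placeholder for a loaded PLM):
#
#     from opendelta import AutoDeltaModel
#     config = SoftPromptConfig(soft_token_num=100, init_range=0.5, token_init=True)
#     delta_model = AutoDeltaModel.from_config(config, backbone_model=backbone)
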
class SoftPromptLayer(nn.Module):
    r"""This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Like :obj:`PrefixTuningTemplate`,
    this template does not need any textual template. The additional tokens are directly
    concatenated to the input embeddings. There are two initializations for the new tokens:
    (1) random initialization; (2) initialization with the token embeddings of the PLM (we simply
    take the first ``n_tokens`` embeddings, similar to the original implementation).

    Note that the same effect can be achieved with :obj:`SoftManualTemplate` by placing
    ``n_token`` ``<soft>`` tokens before ``<text_a>`` in the template.
    """
    def __init__(self,
                 soft_token_num: int = 100,
                 raw_embedding: Optional[nn.Module] = None,
                 init_range: Optional[float] = 0.5,
                 other_expand_ids: Optional[Dict] = None,
                 token_init: bool = False,
                 pad_id: int = 0,
                 device: Optional[str] = None,
                 ):
        super().__init__()
        # keep the backbone embedding out of this module's parameter registry so
        # it is neither saved nor trained as part of the delta
        self.__dict__['raw_embedding'] = raw_embedding
        self.init_range = init_range
        self.num_tokens = soft_token_num
        self.pad_id = pad_id
        self.token_init = token_init
        self.device = device
        if other_expand_ids is None:
            # default mapping; resolved here to avoid a mutable default argument
            other_expand_ids = {"attention_mask": 1, "token_type_ids": 0}
        self.other_expand_ids = other_expand_ids
        assert self.num_tokens > 0
        # infer the hidden size by embedding a dummy token id
        self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
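
    # A standalone usage sketch (assumed; in practice this layer is created and
    # attached automatically by SoftPromptModel below):
    #
    #     emb = backbone.get_input_embeddings()
    #     layer = SoftPromptLayer(soft_token_num=10, raw_embedding=emb)
    #     args, kwargs = layer.pre_forward(input_ids=batch["input_ids"])
    #     # kwargs["inputs_embeds"] now has 10 soft embeddings prepended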
# self.all_pseudo_tokens = {}
    def pre_forward(self, *args, **kwargs):
        # Prepend the soft-prompt embeddings to the input and expand aligned
        # inputs such as attention_mask accordingly.
        if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None:
            # During generation the model is called again with cached encoder
            # outputs; the soft prompt was already prepended, so pass through.
            return args, kwargs
if 'input_ids' in kwargs:
input_ids = kwargs['input_ids']
kwargs['input_ids'] = None
elif len(args) > 0:
input_ids = args[0]
args = args[1:]
else:
input_ids = None
if 'attention_mask' not in kwargs or kwargs['attention_mask'] is None:
# infer attention mask
if input_ids is None:
raise RuntimeError("no input ids found")
kwargs['attention_mask'] = (input_ids != self.pad_id).to(torch.int64)
        if 'inputs_embeds' not in kwargs or kwargs['inputs_embeds'] is None:
            if input_ids is None:
                raise RuntimeError("neither inputs_embeds nor input_ids is specified.")
            inputs_embeds = self.raw_embedding(input_ids)
        else:
            inputs_embeds = kwargs['inputs_embeds']
        batch_size = inputs_embeds.size(0)
        # prepend one copy of the soft prompt to every sequence in the batch
        soft_embeds = self.soft_embeds.repeat(batch_size, 1, 1)
        inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
        kwargs['inputs_embeds'] = inputs_embeds
for expand_key in self.other_expand_ids:
if expand_key in kwargs:
real_tokens = kwargs[expand_key]
# if expand_key in self.all_pseudo_tokens:
# pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
# else:
                pseudo_tokens_value = self.other_expand_ids[expand_key]
                # left-pad the aligned input (e.g. attention_mask) with its filler
                # value so it matches the expanded sequence length
                pseudo_tokens = torch.full(
                    (*real_tokens.shape[:-1], inputs_embeds.shape[-2] - real_tokens.shape[-1]),
                    pseudo_tokens_value,
                    dtype=real_tokens.dtype,
                    device=real_tokens.device)
                # self.all_pseudo_tokens[expand_key] = pseudo_tokens
                # assign through .data so the expanded tensor replaces
                # kwargs[expand_key] in place
                real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
return args, kwargs
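
    # A worked shape example (hypothetical sizes, not from the original file):
    # with ``soft_token_num=100``, ``input_ids`` of shape (2, 8), and hidden size
    # 768, ``pre_forward`` produces
    #     inputs_embeds:  (2, 108, 768)   # 100 soft embeddings + 8 real tokens
    #     attention_mask: (2, 108)        # left-padded with 100 ones
    # and ``token_type_ids``, if present, is left-padded with zeros instead.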
    def instantiate(self, hidden_dim: int) -> None:
        """
        Generate the parameters for the soft-token embeddings: either copy the
        embeddings of the first ``num_tokens`` vocabulary entries of the PLM
        (``token_init=True``) or draw them uniformly from
        ``[-init_range, init_range]``.
        """
        soft_embeds = torch.FloatTensor(self.num_tokens, hidden_dim)
        if self.token_init:
            # copy the embeddings of token ids 0 .. num_tokens-1, detached from
            # the backbone's computation graph
            soft_embeds = self.raw_embedding(torch.arange(self.num_tokens)).detach().clone()
        else:
            soft_embeds = soft_embeds.uniform_(-self.init_range, self.init_range)
        if self.device is not None:
            # move to the target device *before* wrapping: calling .to() on an
            # nn.Parameter returns a plain non-leaf tensor, so the soft prompt
            # would silently fail to register as a trainable parameter
            soft_embeds = soft_embeds.to(self.device)
        self.soft_embeds = nn.Parameter(soft_embeds, requires_grad=True)
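
    # For example (hypothetical sizes): with ``num_tokens=5`` and ``hidden_dim=768``,
    # ``self.soft_embeds`` has shape (5, 768) and is the only trainable tensor
    # this layer introduces.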
class SoftPromptModel(DeltaBase):
    r"""
    This is the implementation of `The Power of Scale for Parameter-Efficient
    Prompt Tuning <https://arxiv.org/pdf/2104.08691v1.pdf>`_ . Like :obj:`PrefixTuningTemplate`,
    this template does not need any textual template. The additional tokens are directly
    concatenated to the input embeddings. There are two initializations for the new tokens:
    (1) random initialization; (2) initialization with the token embeddings of the PLM (we simply
    take the first ``n_tokens`` embeddings, similar to the original implementation).

    Note that the same effect can be achieved with :obj:`SoftManualTemplate` by placing
    ``n_token`` ``<soft>`` tokens before ``<text_a>`` in the template.

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        soft_token_num (:obj:`int`, *optional*): The number of new soft tokens prepended to the input.
        init_range (:obj:`float`, *optional*): When initializing the new tokens randomly, the range
            of the uniform distribution.
        token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new
            tokens with the token embeddings of the PLM.
        other_expand_ids (:obj:`dict`, *optional*, default to ``{"attention_mask":1, "token_type_ids":0}``):
            The names of other inputs that expand along with the input sequence, mapped to their filler
            values. For example, when you prepend 100 tokens to ``input_ids``, the ``attention_mask``
            and the ``token_type_ids`` should be extended as well.
        modified_modules (:obj:`List[str]`): Not used by soft prompt; the delta layer is always
            attached to the root of the backbone model.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that
            should be unfrozen together with the soft-prompt parameters.
        common_structure (:obj:`bool`): Whether to use name-based addressing with a common structure
            mapping. Not used by soft prompt.
    """
config_class = SoftPromptConfig
delta_type = "soft_prompt"
default_modified_modules = ["root"] # not used
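
    # A minimal end-to-end sketch (assumed usage, mirroring the public OpenDelta
    # examples; "t5-small" is an arbitrary placeholder checkpoint):
    #
    #     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    #     backbone = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    #     delta_model = SoftPromptModel(backbone, soft_token_num=20, token_init=True)
    #     delta_model.freeze_module(exclude=["deltas"])  # train only the soft prompt
    #     tokenizer = AutoTokenizer.from_pretrained("t5-small")
    #     batch = tokenizer(["a sample input"], return_tensors="pt")
    #     loss = backbone(**batch, labels=batch["input_ids"]).loss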
    def __init__(self,
                 backbone_model: nn.Module,
                 soft_token_num=100,
                 init_range=0.5,
                 token_init=True,
                 other_expand_ids: Optional[Dict] = None,
                 modified_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 ):
        if other_expand_ids is None:
            # default mapping; resolved here to avoid a mutable default argument
            other_expand_ids = {"attention_mask": 1, "token_type_ids": 0}
        DeltaBase.__init__(self,
                           backbone_model=backbone_model,
                           # the soft prompt is always attached at the root of the
                           # backbone, so modified_modules and common_structure
                           # are fixed regardless of the arguments above
                           modified_modules=["root"],
                           exclude_modules=exclude_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=False,
                           interactive_modify=interactive_modify,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in the parent class
                setattr(self, arg_name, locals()[arg_name])
        try:
            self.__dict__['raw_embedding'] = self.backbone_model.get_input_embeddings()
        except AttributeError:
            raise AttributeError(f"'{type(self.backbone_model)}' object has no attribute "
                                 "'get_input_embeddings'; please assign the input embeddings "
                                 "to 'self.raw_embedding' manually.")
self.delta_modules = nn.ModuleList()
self.add_all_delta_to_backbone(self.backbone_model,
self.modified_modules,
)
    def add_all_delta_to_backbone(self,
                                  module: nn.Module,
                                  modified_modules: List[str],
                                  ) -> nn.Module:
        # ``modified_modules`` is ignored here: the soft prompt is a single layer
        # attached in front of the whole backbone (or its encoder)
        self.update_module()
        self.mark_as_delta()
        return module
    def update_module(self):
        soft_prompt_layer = self.new_module_like(self.raw_embedding)
        # for encoder-decoder backbones, hook the encoder so the soft prompt is
        # prepended to the encoder input only
        host = (self.backbone_model.get_encoder()
                if self.backbone_model.config.is_encoder_decoder
                else self.backbone_model)
        self.insert_sequential_module(host,
                                      delta_module=soft_prompt_layer,
                                      delta_name="soft_prompt_layer")
def new_module_like(self, module):
module_device = get_device(module)
soft_prompt_layer = SoftPromptLayer(
soft_token_num = self.soft_token_num,
raw_embedding = self.raw_embedding,
other_expand_ids = self.other_expand_ids,
token_init = self.token_init,
init_range = self.init_range,
device = module_device,
)
self.delta_modules.append(soft_prompt_layer)
return soft_prompt_layer
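
# A quick sanity-check sketch (hypothetical sizes): after attaching the delta and
# freezing the backbone, the only trainable parameters are the soft embeddings.
#
#     delta_model = SoftPromptModel(backbone, soft_token_num=100)
#     delta_model.freeze_module(exclude=["deltas"])
#     n_trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
#     # with soft_token_num=100 and hidden size 768: n_trainable == 100 * 768 == 76800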