from typing import Optional, Union

from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase
import torch.nn as nn
from opendelta import BaseDeltaConfig
import math
from dataclasses import dataclass, field


class LowRankLinear(nn.Module):
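    r"""A single low-rank branch that is inserted in parallel to a linear layer.

    It stores the two LoRA factors :math:`A \in \mathbb{R}^{r \times d_{in}}` (Kaiming-uniform
    initialized) and :math:`B \in \mathbb{R}^{d_{out} \times r}` (zero initialized, so the branch
    starts as a zero mapping) and computes
    :math:`\mathrm{dropout}(x) A^\top B^\top \cdot \frac{\alpha}{r}`, i.e. the :math:`BAx` term of
    the decomposition :math:`(W + BA)x = Wx + BAx`.
    """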
    #  ------------------------------------------------------------------------------------------
    #  Copyright (c) Microsoft Corporation. All rights reserved.
    #  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
    #  ------------------------------------------------------------------------------------------
    #  copied from loralib and slightly refactored
    def __init__(self,
                 in_features,
                 out_features,
                 weight,
                 r=8,
                 lora_alpha=16,
                 lora_dropout=0.0,
                 ):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        if r > 0:
            self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling


@dataclass
class LoraArguments:
    r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.0


class LoraConfig(BaseDeltaConfig):
    r"""
    This is the configuration class to store the configuration of a :py:class:`~LoraModel`.

    """
    def __init__(
        self,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # the arg has not been registered in the parent config
                setattr(self, arg_name, locals()[arg_name])


class LoraModel(DeltaBase):
    r"""The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.

    Thanks to their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_.

    .. note::

        In our implementation, we do not use loralib.Linear to replace the linear layers of the backbone model.
        Instead, we insert a parallel module into the backbone. In other words, we treat :math:`(W + BA)X` as
        :math:`WX + BAX`, and insert the :math:`BAX` computation as a parallel module. If you want to use the
        original implementation, please refer to `lora_old.py`.

    class attributes:

        - default_modified_modules = ['attn@.q@', 'attn@.v@']. According to the paper, LoRA modifies the query and
          value projection matrices in the attention layer. However, other linear layers can also be modified,
          which may lead to better performance.

        .. note::

            modified_modules should point to a linear layer. We currently do not support broadcasting to all linear
            layers among a module's child modules.

        - delta_type = "lora"

    Args:
        backbone_model (:obj:`transformers.PreTrainedModel`): The backbone model to be modified.
        lora_r (:obj:`int`, *optional*): The rank of the LoRA parameters. The smaller lora_r is, the fewer parameters LoRA has.
        lora_alpha (:obj:`int`, *optional*): A hyper-parameter that scales the low-rank branch; its output is multiplied by ``lora_alpha / lora_r``.
        lora_dropout (:obj:`float`, *optional*): The dropout rate applied to the input of the low-rank branch.
        modified_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules to equip with LoRA. Each entry must refer to a linear layer.
        unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
            together with the LoRA parameters.
        common_structure (:obj:`bool`): Whether to use name-based addressing with a common structure mapping.
        backend (:obj:`str`): The backend of the pre-trained language model: ``'hf'`` for Hugging Face Transformers, ``'bmt'`` for BMTrain.

    """

    config_class = LoraConfig
    delta_type = "lora"
    default_modified_modules = ['attn@.q@', 'attn@.v@']
    _supported_backends = ['hf', 'bmt']
    _need_pseudo_data = False

    def __init__(self,
                 backbone_model: nn.Module,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.0,
                 modified_modules: Optional[List[str]] = None,
                 unfrozen_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 common_structure: Optional[bool] = None,
                 interactive_modify: Optional[Union[bool, int]] = False,
                 backend: Optional[str] = "hf",
                 ):
        DeltaBase.__init__(self,
                           backbone_model,
                           modified_modules=modified_modules,
                           unfrozen_modules=unfrozen_modules,
                           common_structure=common_structure,
                           interactive_modify=interactive_modify,
                           backend=backend,
                           )
        arg_names = get_arg_names_inside_func(self.__init__)
        for arg_name in arg_names:
            if not hasattr(self, arg_name):  # not registered in the parent class
                setattr(self, arg_name, locals()[arg_name])

        self.delta_modules = nn.ModuleList()

        self.add_all_delta_to_backbone(self.backbone_model,
                                       self.modified_modules,
                                       )

    def update_module(self, module: nn.Module, key: str):
        parent_ref, child_name, child_ref = self.find_module(module, key)
        parallel_module = self.new_module_like(child_module=child_ref)
        self.insert_parallel_module(child_ref, delta_module=parallel_module, delta_name="lora")

    def _pseudo_data_to_instantiate(self, module):
        # no need to pass pseudo input, so overwrite it with a no-op
        pass

    def new_module_like(self, child_module):
        in_features, out_features = child_module.in_features, child_module.out_features
        new_module = LowRankLinear(in_features=in_features,
                                   out_features=out_features,
                                   weight=child_module.weight,
                                   r=self.lora_r,
                                   lora_alpha=self.lora_alpha,
                                   lora_dropout=self.lora_dropout)
        if self.backend == "bmt":
            import bmtrain as bmt
            new_module = bmt.BMTrainModelWrapper(new_module)

        self.delta_modules.append(new_module)
        return new_module