This commit is contained in:
Achazwl 2022-02-20 17:23:31 +08:00
parent 0c590e7965
commit c2e086c6ed
21 changed files with 665 additions and 42 deletions

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "cola",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_matthews_correlation",
"learning_rate": 0.0004,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 80,
"output_dir": "outputs/lora/roberta-base/v2/cola",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 32,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "cola",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "cola",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"warmup_steps": 0,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,46 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0005,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "mnli",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_accuracy",
"learning_rate": 0.0005,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 30,
"output_dir": "outputs/lora/roberta-base/v2/mnli",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 16,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "mnli",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "mnli",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0004,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "mrpc",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_accuracy",
"learning_rate": 0.0004,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 30,
"output_dir": "outputs/lora/roberta-base/v2/mrpc",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 16,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "mrpc",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "mrpc",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0004,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "qnli",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_accuracy",
"learning_rate": 0.0004,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 25,
"output_dir": "outputs/lora/roberta-base/v2/qnli",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 32,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "qnli",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "qnli",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0005,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "qqp",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_accuracy",
"learning_rate": 0.0005,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 25,
"output_dir": "outputs/lora/roberta-base/v2/qqp",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 16,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "qqp",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "qqp",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,46 @@
{
"dataset_config_name": [
"en"
],
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "rte",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_accuracy",
"learning_rate": 0.0005,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 80,
"output_dir": "outputs/lora/roberta-base/rte",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 32,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "rte",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "rte",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0005,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "sst2",
"evaluation_strategy": "epoch",
"metric_for_best_model": "eval_accuracy",
"greater_is_better": true,
"learning_rate": 0.0005,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 60,
"output_dir": "outputs/lora/roberta-base/v2/sst2",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 16,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "sst2",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "sst2",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,47 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0004,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "stsb",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_pearson",
"learning_rate": 0.0004,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 40,
"output_dir": "outputs/lora/roberta-base/v2/stsb",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 16,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "stsb",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "stsb",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -0,0 +1,48 @@
{
"dataset_config_name": [
"en"
],
"delta_lr": 0.0005,
"delta_type": "lora",
"do_eval": true,
"do_test": true,
"do_train": true,
"eval_dataset_config_name": [
"en"
],
"eval_dataset_name": "wnli",
"evaluation_strategy": "epoch",
"greater_is_better": true,
"metric_for_best_model": "eval_pearson",
"learning_rate": 0.0003,
"load_best_model_at_end": true,
"lora_alpha": 8,
"lora_rank": 8,
"max_source_length": 512,
"model_name": "roberta",
"model_name_or_path": "roberta-base",
"non_linearity": "gelu_new",
"num_train_epochs": 30,
"output_dir": "outputs/lora/roberta-base/v2/wnli",
"per_device_eval_batch_size": 100,
"per_device_train_batch_size": 32,
"predict_with_generate": true,
"save_strategy": "epoch",
"save_total_limit": 1,
"split_validation_test": true,
"task_name": "wnli",
"test_dataset_config_name": [
"en"
],
"test_dataset_name": "wnli",
"tokenizer_name": "roberta-base",
"unfrozen_modules": [
"classifier",
"deltas"
],
"warmup_ratio": 0.06,
"warmup_steps": 0,
"weight_decay": 0.1,
"overwrite_output_dir": true,
"push_to_hub": false
}

View File

@ -603,6 +603,7 @@ def main():
item = label_list[item]
writer.write(f"{index}\t{item}\n")
# from IPython import embed; embed()
# kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
# if data_args.task_name is not None:

View File

@ -9,6 +9,7 @@ Visualization(model).structure_graph()
from opendelta import LoraModel
import re
delta_model = LoraModel(backbone_model=model, modified_modules=['[r](\d)+\.output.dense', 'attention.output.dense'])
# delta_model = LoraModel(backbone_model=model, modified_modules=['[r][0-5]\.output.dense'])
print("after modify")
delta_model.log()
# This will visualize the backbone after modification and other information.

View File

@ -480,7 +480,7 @@ class DeltaBase(nn.Module, SaveLoadMixin):
"""
raise NotImplementedError
def insert_sequential_module(self, module, delta_module=None, name='delta', strict=False, _delta_info=None):
def insert_sequential_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
r"""insert a module (previous not exists in the code base) before/after a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the new module's forward function and then pass
it into the original ones. The new module can also be inserted after the original module with similar mechanism.
@ -520,14 +520,14 @@ class DeltaBase(nn.Module, SaveLoadMixin):
_delta_info = {"method": "insert_sequential",
"delta_module": delta_module,
"delta_name": name,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
name = _delta_info["delta_name"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
@ -538,19 +538,58 @@ class DeltaBase(nn.Module, SaveLoadMixin):
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
def insert_parrellel_module(self, module, pre_caller=None, post_caller=None, delta_module=None, name='delta'):
def insert_parallel_module(self, module, delta_module=None, delta_name='delta', strict=False, _delta_info=None):
"""insert a module (previous not exists in the code base) across a module. Specifically, it modifies the forward
function of the original module to firstly pass the arguments into the delta model's forward function and set
aside the calculation result. Then combine it with the calculation result output from the backbone module.
When implementing the new module , researchers should be aware of the arguments and keywards of the original module's forward function.
# TODO: currently not in use.
Args:
module: (:obj:`nn.Module`): The (sub)module to inserted a delta module.
delta_module: (:obj:`DeltaBase`): The delta module to be inserted.
name: (:obj:`str`, *optional*): The name of the delta in the backbone module.
strict: (:obj:`bool`, *optional*): Whether to prohibit modify a modified module.
_delta_info (:obj:`Dict`, *optional*): Used in attach(), reattach a delta module to backbone. The info of
original delta is passed through ``_delta_info``.
"""
raise NotImplementedError
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
if strict:
if hasattr(module.forward, "__wrapped__"):
raise RuntimeWarning("The forward function might have been wrapped by a decorator, is it intended?")
# record info for plug and unplug and nested wrap
if _delta_info is None:
if delta_module is None:
raise RuntimeError("delta module can't be none to ensure successful replicate of the parent module.")
_delta_info = {"method": "insert_parallel",
"delta_module": delta_module,
"delta_name": delta_name,
"delta_belong": self,
"state": "on"}
self._register_delta_infos(parent_module=module,
_delta_info = _delta_info)
else:
delta_module = _delta_info["delta_module"]
delta_name = _delta_info["delta_name"]
setattr(module, _delta_info['delta_name'], _delta_info["delta_module"])
new_forward = decorate(module.forward, _caller, extras=(module, _delta_info['delta_name']), kwsyntax=True) # decorator.decorate helps preserving the functions metadata (signature, etc.).
module.forward = new_forward.__get__(module, type(module)) # func.__get__(object, type(object)) register a function as an object's method
# for DataParallel's copy behavior. Experimental:
# may have bugs when module.forward is nestedly wrapped.
module._replicate_for_data_parallel = new_replicate_for_data_parallel.__get__(module, type(module))
def set_active_state_dict(self, module: nn.Module):
r"""modify the state_dict function of the model (by default, the backbone model) to return only the tunable part.

View File

@ -192,7 +192,7 @@ class AdapterModel(DeltaBase):
def update_module(self, module: nn.Module, key: str):
_, _, ref = self.find_module(module, key)
adapterlayer = self.new_module_like(ref)
self.insert_sequential_module(ref, delta_module=adapterlayer, name="adapter")
self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="adapter")
def new_module_like(self, module):
module_device = get_device(module)

View File

@ -179,7 +179,7 @@ class BitFitModel(DeltaBase):
def add_bias_to_others(self, c):
new_bias = BiasLayer()
self.insert_sequential_module(c, delta_module=new_bias, name="bitfit") # name shouldn't be `bias` here, since
self.insert_sequential_module(c, delta_module=new_bias, delta_name="bitfit") # name shouldn't be `bias` here, since
# the name `bias` is reserved for some module such as roberta's LayerNorm.
self.delta_modules.append(new_bias)

View File

@ -277,7 +277,7 @@ class CompacterModel(DeltaBase):
adapterlayer = self.new_module_like(ref)
self.insert_sequential_module(ref,
delta_module=adapterlayer,
name="compactor")
delta_name="compactor")
def new_module_like(self, module):
module_device = get_device(module)

View File

@ -1,3 +1,4 @@
from turtle import forward
from typing import Optional, Union
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
@ -7,6 +8,42 @@ from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn
from opendelta import BaseDeltaConfig
import math
class LowRankLinear(nn.Module):
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
# copy from loralib and do some refactor
def __init__(self,
in_features,
out_features,
weight,
r=8,
lora_alpha=16,
lora_dropout=0.0,
):
super().__init__()
self.r = r
self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout
self.lin = nn.Linear(in_features, out_features) #
if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
if r > 0:
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
self.lin.reset_parameters() #
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def forward(self, x):
return (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
class LoraConfig(BaseDeltaConfig):
r"""
@ -27,7 +64,6 @@ class LoraConfig(BaseDeltaConfig):
setattr(self, arg_name, locals()[arg_name])
class LoraModel(DeltaBase):
r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
Thanks for their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
@ -89,11 +125,10 @@ class LoraModel(DeltaBase):
)
def update_module(self, module: nn.Module, key: str):
parent_ref, child_name, child_ref = self.find_module(module, key)
new_module = self.new_module_like(child_module=child_ref)
self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")
parallel_module = self.new_module_like(child_module=child_ref)
self.insert_parallel_module(child_ref, delta_module=parallel_module, delta_name="lora")
def _pseudo_data_to_instantiate(self, module):
# no need to pass pseudo input, so overwrite it
@ -102,26 +137,13 @@ class LoraModel(DeltaBase):
def new_module_like(self, child_module):
if isinstance(child_module, nn.Linear):
in_features, out_features = child_module.in_features, child_module.out_features
new_module = lora.Linear(in_features=in_features,
out_features=out_features,
new_module = LowRankLinear(in_features = in_features,
out_features = out_features,
weight = child_module.weight,
r=self.lora_r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
new_module.weight = child_module.weight
new_module.bias = child_module.bias # if bias is None, also copy
self.delta_modules.append(new_module)
else:
raise NotImplementedError
return new_module
def mark_as_delta(self, module: nn.Module = None):
if module is None:
module=self
for n, p in module.named_parameters():
param_name = n.split(".")[-1]
if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter.
setattr(p, "_is_delta", True)

View File

@ -0,0 +1,126 @@
from typing import Optional, Union
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn
from opendelta import BaseDeltaConfig
class LoraConfig(BaseDeltaConfig):
r"""
This is the configuration class to store the configuration of a :py:class:`~LoraModel`
"""
def __init__(
self,
lora_r=8,
lora_alpha=16,
lora_dropout=0.0,
**kwargs
):
super().__init__(**kwargs)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
if not hasattr(self, arg_name): # the arg has not been registered in parent config
setattr(self, arg_name, locals()[arg_name])
class LoraModel(DeltaBase):
r""" The implementation of `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_ .
Thanks for their `loralib <https://github.com/microsoft/LoRA/tree/main/loralib>`_, we use loralib.linear
to replace the linear layer of the backbone model.
class attributes:
- default_modified_modules = ['attn.q', 'attn.v'] According to the paper, they modify q and v matrix in the
attention layer. However, other linears can also be modified, and may lead to better performance.
.. note::
modified_modules should point to linear layer. We currently don't support broadcast to all linears in
a module's child modules.
- delta_type = "lora"
Args:
backbone_model (:obj:`transformers.PretrainedModels`): The backbone model to be modified.
lora_r (:obj:`int`, *optional*): the rank of the lora parameters. The smaller lora_r is , the fewer parameters lora has.
lora_alpha (:obj:`bool`, *optional*): A hyper-parameter to control the init scale of loralib.linear .
lora_dropout (:obj:`bool`, *optional*): The dropout rate in lora.linear.
modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
the implemented ones)
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
together with the prefix parameters.
common_structure (:obj:`bool`): whether using name-based addressing witha common structure mapping.
"""
config_class = LoraConfig
delta_type = "lora"
default_modified_modules = ['attn.q', 'attn.v']
def __init__(self,
backbone_model: nn.Module,
lora_r=8,
lora_alpha=16,
lora_dropout=0.0,
modified_modules: Optional[bool] = None,
unfrozen_modules: Optional[bool] = None,
common_structure: Optional[bool] = None,
interactive_modify: Optional[Union[bool, int]] = False,
):
DeltaBase.__init__(self,
backbone_model,
modified_modules=modified_modules,
unfrozen_modules=unfrozen_modules,
common_structure=common_structure,
interactive_modify=interactive_modify,
)
arg_names = get_arg_names_inside_func(self.__init__)
for arg_name in arg_names:
if not hasattr(self, arg_name): # not registered in parent class
setattr(self, arg_name, locals()[arg_name])
self.delta_modules = nn.ModuleList()
self.add_all_delta_to_backbone(self.backbone_model,
self.modified_modules,
)
def update_module(self, module: nn.Module, key: str):
parent_ref, child_name, child_ref = self.find_module(module, key)
new_module = self.new_module_like(child_module=child_ref)
self.replace_module(parent_ref, child_name, child_ref, new_module, delta_name="lora")
def _pseudo_data_to_instantiate(self, module):
# no need to pass pseudo input, so overwrite it
pass
def new_module_like(self, child_module):
if isinstance(child_module, nn.Linear):
in_features, out_features = child_module.in_features, child_module.out_features
new_module = lora.Linear(in_features=in_features,
out_features=out_features,
r=self.lora_r,
lora_alpha=self.lora_alpha,
lora_dropout=self.lora_dropout)
new_module.weight = child_module.weight
new_module.bias = child_module.bias # if bias is None, also copy
else:
raise NotImplementedError
return new_module
def mark_as_delta(self, module: nn.Module = None):
if module is None:
module=self
for n, p in module.named_parameters():
param_name = n.split(".")[-1]
if "lora_A" in param_name or "lora_B" in param_name: # only lora_A, lora_B is the delta parameter.
setattr(p, "_is_delta", True)

View File

@ -194,7 +194,7 @@ class LowRankAdapterModel(DeltaBase):
def update_module(self, module: nn.Module, key: str):
_, _, ref = self.find_module(module, key)
adapterlayer = self.new_module_like(ref)
self.insert_sequential_module(ref, delta_module=adapterlayer, name="low_rank_adapter")
self.insert_sequential_module(ref, delta_module=adapterlayer, delta_name="low_rank_adapter")
def new_module_like(self, module):
module_device = get_device(module)

View File

@ -512,7 +512,7 @@ class PrefixModel(DeltaBase):
module_list=self.delta_modules)
self.delta_modules = None
self.reparams = reparams
self.insert_sequential_module(first_modified_module, delta_module=reparams, name="reparams", strict=False)
self.insert_sequential_module(first_modified_module, delta_module=reparams, delta_name="reparams", strict=False)
self.mark_as_delta()
return module
@ -522,7 +522,7 @@ class PrefixModel(DeltaBase):
_, _, ref = self.find_module(module, key)
prefixlayer, ref = self.new_module_like(ref)
self.insert_sequential_module(ref, delta_module=prefixlayer, name="prefix")
self.insert_sequential_module(ref, delta_module=prefixlayer, delta_name="prefix")
self.delta_modules.append(prefixlayer)
def new_module_like(self, module):

View File

@ -193,7 +193,7 @@ class SoftPromptModel(DeltaBase):
soft_prompt_layer = self.new_module_like(self.raw_embedding)
self.insert_sequential_module(self.backbone_model.get_encoder() if self.backbone_model.config.is_encoder_decoder else self.backbone_model,
delta_module=soft_prompt_layer,
name="soft_prompt_layer" )
delta_name="soft_prompt_layer" )
def new_module_like(self, module):
module_device = get_device(module)

View File

@ -8,7 +8,7 @@ def new_replicate_for_data_parallel(self):
r""" self is the parent module.
"""
# rewrite the replicate in DataParallel.
def _caller(_org_func, org_module, delta_name, *args, **kwargs):
def _sequential_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
if hasattr(delta_module, "pre_forward"):
@ -17,6 +17,13 @@ def new_replicate_for_data_parallel(self):
if hasattr(delta_module, "post_forward"):
ret = delta_module.post_forward(ret)
return ret
def _parallel_caller(_org_func, org_module, delta_name, *args, **kwargs):
args = args[1:] # the first argument here is ``self``
delta_module = getattr(org_module, delta_name)
ret_1 = _org_func(*args, **kwargs)
ret_2 = delta_module.forward(*args, **kwargs)
return ret_1 + ret_2
replica = self.__new__(type(self))
org_forward = replica.forward
replica.__dict__ = self.__dict__.copy()
@ -25,8 +32,13 @@ def new_replicate_for_data_parallel(self):
for _delta_info in self._delta_infos:
if _delta_info['method'] == "insert_sequential" and _delta_info['state'] == "on":
new_forward = decorate(replica.forward, _caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
if _delta_info['state'] == 'on':
if _delta_info['method'] == "insert_sequential":
new_forward = decorate(replica.forward, _sequential_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
elif _delta_info['method'] == "insert_parallel":
new_forward = decorate(replica.forward, _parallel_caller, extras=(replica, _delta_info['delta_name']), kwsyntax=True)
else:
raise NotImplementedError(f"data_parallel for _delta_info['method']=='{_delta_info['method']}' is not supported")
replica.__dict__['forward'] = new_forward.__get__(replica, type(replica))
# replicas do not have parameters themselves, the replicas reference the original