fixbug
parent da6efddc95
commit 27de7e0ac6
@@ -60,7 +60,7 @@ def get_prompts(task, tokenizer, data_args, template_id="0", verbalizer_id="0"):
     template = ManualTemplate(text = task.templates_text[template_id])
     verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
     tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
-    from IPython import embed; embed()
+    # from IPython import embed; embed()
     return template, verbalizer, tokenizer_wrapper

 class DataCollator(HfDataCollatorMixin):
@@ -14,253 +14,6 @@ parser.add_argument("--job", type=str)
parser.add_argument("--")
args = parser.parse_args()

BaseConfigs = {}
BaseConfigs['t5-base'] = {
    ("job_name", "task_name", "eval_dataset_name", "test_dataset_name", "num_train_epochs",
     "max_source_length",
     "per_device_train_batch_size", "per_device_eval_batch_size", "warmup_steps","save_steps", "eval_steps"): zip(
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record",
         "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        [ 20, 20, 40, 20, 3, 3, 20, 20, 20, 3, 3, 20, 3, 3, 20],
        [256, 256, 256, 256, 256, 512, 256, 128, 128, 128, 128, 128, 128, 128, 128],
        [ 32, 32, 32, 32, 32, 16, 32] + [32] * 8,
        [ 32, 32, 32, 32, 32, 16, 32] + [32] * 8,
        [0] *7 +[0] *8,
        [200, 100, 50, 100, 200, 200, 100, 200, 100, 200, 200, 100, 200, 200, 100],
        [200, 100, 50, 100, 200, 200, 100, 200, 100, 200, 200, 100, 200, 200, 100],
    ),
    "do_train": True,
    "do_eval": True,
    "do_test": True,

    "model_name_or_path": f"{PATHBASE}t5-base",
    "tokenizer_name": f"{PATHBASE}t5-base",
    "save_total_limit": 1,
    # For glue datasets.
    "split_validation_test": True,
    "seed": 42,
    "dataset_config_name": ["en"],
    "eval_dataset_config_name": ["en"],
    "test_dataset_config_name": ["en"],
    # other configurations.
    "predict_with_generate": True,
    # To evaluate during training.
    "load_best_model_at_end": True,
    "metric_for_best_model": "average_metrics",
    "greater_is_better": True,
    "evaluation_strategy": "steps",
    "overwrite_output_dir": True,
    "push_to_hub": False,
    "push_to_delta_center": True,
    "save_strategy": "steps"
}

AllConfigs['bitfit_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['bitfit_t5-base'].update({
    "delta_type": "bitfit",
    "learning_rate": 3e-4,
    "output_dir": "outputs/bitfit/t5-base/",
})

AllConfigs['adapter_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['adapter_t5-base'].update({
    "delta_type": "adapter",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
        "layer_norm",
        "final_layer_norm"
    ],
    "bottleneck_dim":24,
    "output_dir": "outputs/adapter/t5-base/",
})

AllConfigs['lora_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['lora_t5-base'].update({
    "delta_type": "lora",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
        "layer_norm",
        "final_layer_norm"
    ],
    "lora_r": 8,
    "output_dir": "outputs/lora/t5-base/",
})

AllConfigs['compacter_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['compacter_t5-base'].update({
    "delta_type": "compacter",
    "learning_rate": 3e-3,
    "unfrozen_modules": [
        "deltas",
        "layer_norm",
        "final_layer_norm"
    ],
    "output_dir": "outputs/compacter/t5-base/",
    "non_linearity": "gelu_new",

    #Compacter.
    "hypercomplex_division": 4,
    "hypercomplex_adapters": True,
    "hypercomplex_nonlinearity": "glorot-uniform",
    # gradient clip and clamp
    "gradient_clip": False,
    "phm_clamp": False,
    "normalize_phm_weight": False,
    "learn_phm": True,
    # shared one side
    "factorized_phm": True,
    "shared_phm_rule": False,
    "factorized_phm_rule": False,
    "phm_c_init": "normal",
    "phm_init_range": 0.0001,
    "use_bias_down_sampler": True,
    "use_bias_up_sampler": True,
})

AllConfigs['compacter++_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['compacter++_t5-base'].update({
    "delta_type": "compacter",
    "learning_rate": 3e-3,
    "do_train": True,
    "do_eval": True,
    "do_test": True,
    "modified_modules": [
        "DenseReluDense"
    ],
    "unfrozen_modules": [
        "deltas",
        "layer_norm",
        "final_layer_norm"
    ],
    "output_dir": "outputs/compacter++/t5-base/",
    "non_linearity": "gelu_new",

    #Compacter.
    "hypercomplex_division": 4,
    "hypercomplex_adapters": True,
    "hypercomplex_nonlinearity": "glorot-uniform",
    # gradient clip and clamp
    "gradient_clip": False,
    "phm_clamp": False,
    "normalize_phm_weight": False,
    "learn_phm": True,
    # shared one side
    "factorized_phm": True,
    "shared_phm_rule": False,
    "factorized_phm_rule": False,
    "phm_c_init": "normal",
    "phm_init_range": 0.0001,
    "use_bias_down_sampler": True,
    "use_bias_up_sampler": True,
})


AllConfigs['low_rank_adapter_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['low_rank_adapter_t5-base'].update({
    "delta_type": "low_rank_adapter",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
        "layer_norm",
        "final_layer_norm"
    ],
    "output_dir": "outputs/low_rank_adapter/t5-base/",
    "non_linearity": "gelu_new",
    "low_rank_w_init": "glorot-uniform",
    "low_rank_rank": 1,
})


AllConfigs['soft_prompt_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['soft_prompt_t5-base'].update({
    "delta_type": "soft_prompt",
    "learning_rate": 3e-2,
    "soft_token_num":100,
    "token_init": False,
    "unfrozen_modules": [
        "deltas",
    ],
    "output_dir": "outputs/soft_prompt/t5-base/",
})

AllConfigs['prefix_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['prefix_t5-base'].update({
    "delta_type": "prefix",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
    ],
    "output_dir": "outputs/prefix/t5-base/",
})

AllConfigs['soft_prompt_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
AllConfigs['soft_prompt_t5-base'].update({
    "delta_type": "soft_prompt",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
    ],
    "output_dir": "outputs/soft_prompt/t5-base/",
})
#### T5-base
BaseConfigs['t5-small'] = {
    ("job_name", "task_name", "eval_dataset_name", "test_dataset_name", "num_train_epochs",
     "max_source_length",
     "per_device_train_batch_size", "per_device_eval_batch_size", "warmup_steps","save_steps", "eval_steps"): zip(
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record",
         "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        ["superglue-boolq", "superglue-cb", "superglue-copa", "superglue-wic", "superglue-multirc", "superglue-record", "superglue-wsc.fixed", "mrpc", "cola", "sst2", "qnli", "rte", "mnli", "qqp", "stsb"],
        [ 20, 20, 40, 20, 3, 3, 20, 20, 20, 3, 3, 20, 3, 3, 20],
        [256, 256, 256, 256, 256, 512, 256, 128, 128, 128, 128, 128, 128, 128, 128],
        [ 32, 32, 32, 32, 32, 16, 32] + [32] * 8,
        [ 32, 32, 32, 32, 32, 16, 32] + [32] * 8,
        [0] *7 +[0] *8,
        [200, 100, 50, 100, 200, 200, 100, 200, 100, 200, 200, 100, 200, 200, 100],
        [200, 100, 50, 100, 200, 200, 100, 200, 100, 200, 200, 100, 200, 200, 100],
    ),
    "do_train": True,
    "do_eval": True,
    "do_test": True,

    "model_name_or_path": f"{PATHBASE}t5-small",
    "tokenizer_name": f"{PATHBASE}t5-small",
    "save_total_limit": 1,
    # For glue datasets.
    "split_validation_test": True,
    "seed": 42,
    "dataset_config_name": ["en"],
    "eval_dataset_config_name": ["en"],
    "test_dataset_config_name": ["en"],
    # other configurations.
    "predict_with_generate": True,
    # To evaluate during training.
    "load_best_model_at_end": True,
    "metric_for_best_model": "average_metrics",
    "greater_is_better": True,
    "evaluation_strategy": "steps",
    "overwrite_output_dir": True,
    "push_to_hub": False,
    "push_to_delta_center": True,
    "save_strategy": "steps"
}

AllConfigs['prefix_t5-small'] = copy.deepcopy(BaseConfigs['t5-small'])
AllConfigs['prefix_t5-small'].update({
    "delta_type": "prefix",
    "learning_rate": 3e-4,
    "unfrozen_modules": [
        "deltas",
    ],
    "output_dir": "outputs/prefix/t5-small/",
})


if __name__ == "__main__":
@@ -2,7 +2,7 @@ import collections
 import copy

 PATHBASE="/mnt/sfs_turbo/hsd/plm_cache/"
-PATHBASE="/home/hushengding/plm_cache/"
+# PATHBASE="/home/hushengding/plm_cache/"

 AllConfigs = {}

@@ -51,7 +51,9 @@ BaseConfigs['bert-base-cased'] = {
     "overwrite_output_dir": True,
     "push_to_hub": False,
     "push_to_delta_center": True,
-    "save_strategy": "steps"
+    "save_strategy": "steps",
+    "datasets_load_from_disk": True,
+    "datasets_saved_path": "/mnt/sfs_turbo/hsd/huggingface_datasets/saved_to_disk/"
 }

 AllConfigs['prefix_bert-base-cased'] = copy.deepcopy(BaseConfigs['bert-base-cased'])
@@ -74,6 +76,13 @@ AllConfigs['soft_prompt_bert-base-cased'].update({
     "output_dir": "outputs/soft_prompt/bert-base-cased/",
 })

+AllConfigs['prefix_bert-large-cased'] = copy.deepcopy(AllConfigs['prefix_bert-base-cased'])
+AllConfigs['prefix_bert-large-cased'].update({
+    "output_dir": "outputs/prefix/bert-large-cased/",
+    "model_name_or_path": f"{PATHBASE}bert-large-cased",
+    "tokenizer_name": f"{PATHBASE}bert-large-cased",
+})
+
 if __name__ == "__main__":
     import argparse
     import json
@@ -2,7 +2,7 @@ import collections
 import copy

 PATHBASE="/mnt/sfs_turbo/hsd/plm_cache/"
-PATHBASE="/home/hushengding/plm_cache/"
+# PATHBASE="/home/hushengding/plm_cache/"

 AllConfigs = {}

@@ -47,7 +47,9 @@ BaseConfigs['t5-base'] = {
     "overwrite_output_dir": True,
     "push_to_hub": False,
     "push_to_delta_center": True,
-    "save_strategy": "steps"
+    "save_strategy": "steps",
+    "datasets_load_from_disk": True,
+    "datasets_saved_path": "/mnt/sfs_turbo/hsd/huggingface_datasets/saved_to_disk/"
 }

 AllConfigs['bitfit_t5-base'] = copy.deepcopy(BaseConfigs['t5-base'])
@@ -1,2 +1,4 @@
 optuna
-sklearn
+sklearn
+openpromptu
+tensorboard
@@ -1,16 +1,12 @@
from functools import partial
from random import random

from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
import loralib as lora
import torch.nn as nn
import torch
import math
from opendelta.delta_models.layers.activations import Activations
import inspect
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
import numpy as np
@@ -2,15 +2,11 @@ from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase, is_leaf_module
from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn

from transformers.models.bert.modeling_bert import BertForMaskedLM
import torch
from torch.nn import init
import math
from opendelta.utils.structure_mapping import transform
from opendelta import BaseDeltaConfig
import opendelta.utils.logging as logging
logger = logging.get_logger(__name__)
@@ -5,11 +5,8 @@ from opendelta.utils.signature import get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
import loralib as lora
import torch.nn as nn
import torch
import math
import opendelta
from opendelta.delta_models.layers.activations import Activations
import inspect
from opendelta.delta_models.layers.hypercomplex_linear import PHMLinear
@@ -1,11 +1,9 @@
from turtle import forward
from typing import Optional, Union

from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import *
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn
from opendelta import BaseDeltaConfig
import math
@@ -7,12 +7,10 @@ from typing import Optional, Union
from opendelta.utils.signature import get_arg_names_inside_func
import torch.nn as nn
import torch
from functools import partial
from typing import Optional
from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
import loralib as lora
import torch.nn as nn
import torch
import math
@@ -4,7 +4,7 @@ from opendelta.utils.signature import get_arg_names_inside_func, signature
 from typing import Optional, Union
 from transformers.models.distilbert.modeling_distilbert import MultiHeadSelfAttention
 from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention
-from transformers.models.bert.modeling_bert import BertSelfAttention
+from transformers.models.bert.modeling_bert import BertAttention
 from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 from transformers.models.bart.modeling_bart import BartAttention
 from transformers.models.roberta.modeling_roberta import RobertaAttention
@@ -12,7 +12,6 @@ from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device
from opendelta.basemodel import DeltaBase
from transformers.models.t5 import T5ForConditionalGeneration
import loralib as lora
import torch.nn as nn
import torch
import opendelta.utils.logging as logging
@@ -266,6 +265,64 @@ class PrefixLayerDistilBert(nn.Module):
         return output


+class PrefixLayerBert(nn.Module):
+    r"""A layer of prefix tuning module. The layer's forward function pass (or concatenate) the additional past_key_value
+    into the original attention layer's forward function.
+    """
+    def __init__(self, prefix_token_num, num_heads, device,):
+        super().__init__()
+        self.prefix_token_num = prefix_token_num
+        self.num_heads = num_heads
+        self.device = device
+        self.instantiated = False
+
+    def instantiate(self, hidden_dim):
+        self.past_key = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
+        self.past_value = nn.Parameter(torch.randn(self.prefix_token_num, hidden_dim, device=self.device), requires_grad=True)
+        self.past_key_reparam = None
+        self.past_value_reparam = None
+        self.instantiated = True
+
+
+    def pre_forward(self, *args, **kwargs):
+        r"""The args and kwargs are inherited from the T5Attention's forward function.
+        """
+        batch_size = args[0].shape[0]
+        if not self.instantiated:
+            self.hidden_dim = args[0].shape[-1]
+            self.instantiate(hidden_dim=self.hidden_dim)
+        if self.past_key_reparam is None:
+            past_key = self.past_key.data
+        else:
+            past_key = self.past_key_reparam
+        if self.past_value_reparam is None:
+            past_value = self.past_value.data
+        else:
+            past_value = self.past_value_reparam
+
+
+        def expand_batchsize(x):
+            x = x.reshape(self.prefix_token_num, self.num_heads, -1).transpose(0,1)
+            x = x.unsqueeze(0).expand(batch_size, *x.shape)
+            return x
+        # from IPython import embe
+
+        if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
+            kwargs['past_key_value'] = (expand_batchsize(past_key), expand_batchsize(past_value))
+
+        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
+            am = kwargs['attention_mask'] # Should check the format of the attention_mask when moving to a new plm.
+            kwargs['attention_mask'] = torch.cat([-torch.zeros((*am.shape[:-1],self.prefix_token_num), dtype = am.dtype,device=am.device), am], dim=-1)
+        elif len(args) >1: # attention mask is passed via positional argument
+            am = args[1]
+            am = torch.cat([-torch.zeros((*am.shape[:-1],self.prefix_token_num), dtype = am.dtype,device=am.device), am], dim=-1)
+            args = (args[0], am) + args[2:]
+        # from IPython import embed
+        # embed(header = "Herein prefixroberta")
+        return args, kwargs
+
+
+
 class PrefixLayerRoberta(nn.Module):
     r"""A layer of prefix tuning module. The layer's forward function pass (or concatenate) the additional past_key_value
     into the original attention layer's forward function.
@@ -540,8 +597,9 @@ class PrefixModel(DeltaBase):
             prefixlayer = PrefixLayerDistilBert(prefix_token_num=self.prefix_token_num, device=module_device)
             self.insert_sequential_module(getattr(module, "k_lin"), pre_caller=prefixlayer.key_pre_forward, post_caller=prefixlayer.key_forward)
             self.insert_sequential_module(getattr(module, "v_lin"), pre_caller=prefixlayer.value_pre_forward, post_caller=prefixlayer.value_forward)
-        elif isinstance(module, BertSelfAttention):
-            raise NotImplementedError
+        elif isinstance(module, BertAttention):
+            module_device = get_device(module)
+            prefixlayer = PrefixLayerBert(prefix_token_num=self.prefix_token_num, num_heads=module.self.num_attention_heads ,device=module_device)
         elif isinstance(module, RobertaAttention):
             module_device = get_device(module)
             prefixlayer = PrefixLayerRoberta(prefix_token_num=self.prefix_token_num, num_heads=module.self.num_attention_heads,device=module_device)
@@ -3,7 +3,7 @@ transformers>=4.10.0
 datasets==1.17.0
 sentencepiece>=0.1.96
 tqdm>=4.62.2
-loralib
+# loralib
 decorator
 rich
 web.py