example_prompt_unify

This commit is contained in:
shengdinghu 2022-04-18 23:28:13 +08:00
parent 151e3dd2c2
commit 0de3cbd31d
13 changed files with 814 additions and 566 deletions

BIN
dist/opendelta-0.0.4-py3-none-any.whl vendored Normal file

Binary file not shown.

BIN
dist/opendelta-0.0.4.tar.gz vendored Normal file

Binary file not shown.

View File

@ -1,3 +1,3 @@
from .tasks import TASK_MAPPING, AutoTask from .tasks import TASK_MAPPING, AutoTask
from .data_collator import TaskDataCollatorForSeq2Seq # from .data_collator import TaskDataCollatorForSeq2Seq
from .postprocessors import AutoPostProcessor # from .postprocessors import AutoPostProcessor

View File

@ -1,28 +1,28 @@
import numpy as np # import numpy as np
from dataclasses import dataclass # from dataclasses import dataclass
from transformers import DataCollatorForSeq2Seq # from transformers import DataCollatorForSeq2Seq
@dataclass # @dataclass
class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): # class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def check_uniqueness(self, samples):
assert len(np.unique(samples)) == 1
def __call__(self, features):
# tasks = [d.pop('task') for d in features]
# self.check_uniqueness(tasks)
output = super().__call__(features)
# output["task"] = tasks[0]
return output
# class CustomDataCollator(DefaultDataCollator):
# def check_uniqueness(self, samples): # def check_uniqueness(self, samples):
# assert len(np.unique(samples)) == 1 # assert len(np.unique(samples)) == 1
# def __call__(self, features): # def __call__(self, features):
# mask_positions = [d.pop('mask_positions') for d in features] # # tasks = [d.pop('task') for d in features]
# # self.check_uniqueness(tasks) # # self.check_uniqueness(tasks)
# output = super().__call__(features) # output = super().__call__(features)
# # output["task"] = tasks[0] # # output["task"] = tasks[0]
# return output # return output
# # class CustomDataCollator(DefaultDataCollator):
# # def check_uniqueness(self, samples):
# # assert len(np.unique(samples)) == 1
# # def __call__(self, features):
# # mask_positions = [d.pop('mask_positions') for d in features]
# # # self.check_uniqueness(tasks)
# # output = super().__call__(features)
# # # output["task"] = tasks[0]
# # return output

View File

@ -15,6 +15,7 @@ import re
from openprompt.prompts import ManualTemplate, ManualVerbalizer from openprompt.prompts import ManualTemplate, ManualVerbalizer
from openprompt.plms.utils import TokenizerWrapper from openprompt.plms.utils import TokenizerWrapper
from openprompt.data_utils import InputExample from openprompt.data_utils import InputExample
from openprompt.prompts import GenerationVerbalizer
import itertools import itertools
@ -28,184 +29,170 @@ from typing import List, Dict
from collections import defaultdict from collections import defaultdict
from openprompt.utils import round_list from openprompt.utils import round_list
import warnings import warnings
class MLMTokenizerWrapper:
def __init__(self, max_seq_length, tokenizer, truncate_method): # class MLMTokenizerWrapper:
self.max_seq_length=max_seq_length # def __init__(self, max_seq_length, tokenizer, truncate_method, mask_token_func=lambda i: "<mask>"):
self.tokenizer=tokenizer # self.max_seq_length=max_seq_length
self.num_special_tokens_to_add = len(tokenizer("")['input_ids']) # self.tokenizer=tokenizer
# from IPython import embed; embed(header="Truega") # self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
self.truncate_method=truncate_method # # from IPython import embed; embed(header="Truega")
self.total_passed_sentences = 0 # self.truncate_method=truncate_method
self.num_truncated_sentences = 0 # self.total_passed_sentences = 0
if truncate_method=='tail': # self.num_truncated_sentences = 0
self.truncate_fct = self.truncate_from_tail # self.mask_token_func = mask_token_func
elif truncate_method=='head':
self.truncate_fct = self.truncate_from_head # if truncate_method=='tail':
elif truncate_method == 'balanced': # self.truncate_fct = self.truncate_from_tail
self.truncate_fct = self.balanced_truncate # elif truncate_method=='head':
else: # self.truncate_fct = self.truncate_from_head
raise NotImplementedError # elif truncate_method == 'balanced':
# self.truncate_fct = self.balanced_truncate
# else:
# raise NotImplementedError
# def merge_wrapped_example(self, wrapped_example,):
# ''' # TODO doens't consider the situation that input has two parts
# '''
# wrapped_example
# # for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
# # the input. Or in generation tasks, we need to generate a piece of target_text.
# # In these case, it tokenized to the encoded_tgt_text for furture use.
# encoder_inputs = defaultdict(list)
# # from IPython import embed; embed(header="Line 67")
# mask_count = 0
# for piece in wrapped_example:
# if piece['text'] == "<mask>":
# encode_text = self.tokenizer.encode(self.mask_token_func(mask_count), add_special_tokens=False, return_special_tokens_mask=True )
# mask_count += 1
# else:
# encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
# encoder_inputs['input_ids'].append(encode_text)
# encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
def merge_wrapped_example(self, wrapped_example, ): # encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
''' # TODO doens't consider the situation that input has two parts # encoder_inputs.pop("shortenable_ids")
''' # encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
# decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
wrapped_example, others = wrapped_example # return decoded_inputs
# for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
# the input. Or in generation tasks, we need to generate a piece of target_text. # @staticmethod
# In these case, it tokenized to the encoded_tgt_text for furture use. # def balanced_truncate(input_dict: Dict,
# num_tokens_to_truncate: int=0) -> Dict:
# '''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
# '''
# shortenable_lens = [len(parts) if parts[0]==1 else 0
# for parts in input_dict['shortenable_ids']]
# total_shortenable_len = sum(shortenable_lens)
# num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
# for part_len in shortenable_lens]
# round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
# truncated_example = defaultdict(list)
# for key in input_dict:
# parts = input_dict[key]
# for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
# truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
# return truncated_example
# @staticmethod
# def truncate_from_tail(input_dict: Dict,
# num_tokens_to_truncate: int=0) -> Dict:
# r"""truncate the inputs from the rear
# """
# truncated_example = defaultdict(list)
# shortenable_ids = input_dict['shortenable_ids']
# for key in input_dict:
# parts = input_dict[key]
# to_trunc = num_tokens_to_truncate
# for i, part in enumerate(parts[::-1]):
# if len(part) == 0: # to prevent some part are empty after tokenization
# continue
# if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
# continue
# parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
# to_trunc -= len(part)
# if to_trunc <= 0:
# break
# truncated_example[key] = parts
# return truncated_example
# @staticmethod
# def truncate_from_head(input_dict: Dict,
# num_tokens_to_truncate: int=0) -> Dict:
# r"""truncate the inputs from the head
# """
# truncated_example = defaultdict(list)
# shortenable_ids = input_dict['shortenable_ids']
# for key in input_dict:
# parts = input_dict[key]
# to_trunc = num_tokens_to_truncate
# for i, part in enumerate(parts):
# if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
# continue
# parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
# to_trunc -= len(part)
# if to_trunc <= 0:
# break
# truncated_example[key] = parts
# return truncated_example
# @staticmethod
# def concate_parts(input_dict: Dict) -> Dict:
# for key in input_dict:
# input_dict[key] = list(itertools.chain(*input_dict[key]))
# return input_dict
# def truncate(self, encoder_inputs):
# total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
# num_specials = self.num_special_tokens_to_add
# # print("num_specials", num_specials)
# num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
# self.total_passed_sentences+=1
# if num_tokens_to_truncate>0:
# self.num_truncated_sentences += 1
# if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
# raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
# encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
# num_tokens_to_truncate=num_tokens_to_truncate)
# return encoder_inputs
# def tokenizer_preprocessor(self, example):
# # source, target = example
# # from IPython import embed; embed(header="Trehre2")
# label = example['label']
# guid = example['idx']
# meta = dict(example)
# meta.pop("label")
# meta.pop("idx")
encoder_inputs = defaultdict(list) # # from IPython import embed; embed(header="Trehre2")
for piece in wrapped_example:
encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
encoder_inputs['input_ids'].append(encode_text)
encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
# e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
encoder_inputs = self.truncate(encoder_inputs=encoder_inputs) # if self.predict_with_generate:
encoder_inputs.pop("shortenable_ids") # e = self.verbalizer.wrap_one_example(e)
encoder_inputs = self.concate_parts(input_dict=encoder_inputs) # example_wrapped = self.template.wrap_one_example(e)
decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False) # encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(example_wrapped)
# print(encoded_sentence)
# again_encode = self.tokenizer.encode(decoded_inputs, add_special_tokens=False, return_special_tokens_mask=True) # if self.predict_with_generate:
# if len(again_encode)> self.max_seq_length - 2: # # return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
# print("length exceed!") # return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
# print(wrapped_example) # else:
# print(encoder_inputs['input_ids']) # return {"source": encoded_sentence, "label": label, 'target': e.target_text, 'extra_fields':{'dataset_name':self.name}}
# print(again_encode)
# print(decoded_inputs)
# exit()
# delete shortenable ids
# encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
# encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
# # create special input ids
# encoder_inputs['attention_mask'] = [1] *len(encoder_inputs['input_ids'])
# # padding
# encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)
return decoded_inputs
@staticmethod
def balanced_truncate(input_dict: Dict,
num_tokens_to_truncate: int=0) -> Dict:
'''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
'''
shortenable_lens = [len(parts) if parts[0]==1 else 0
for parts in input_dict['shortenable_ids']]
total_shortenable_len = sum(shortenable_lens)
num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
for part_len in shortenable_lens]
round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
truncated_example = defaultdict(list)
for key in input_dict:
parts = input_dict[key]
for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
return truncated_example
@staticmethod
def truncate_from_tail(input_dict: Dict,
num_tokens_to_truncate: int=0) -> Dict:
r"""truncate the inputs from the rear
"""
truncated_example = defaultdict(list)
shortenable_ids = input_dict['shortenable_ids']
for key in input_dict:
parts = input_dict[key]
to_trunc = num_tokens_to_truncate
for i, part in enumerate(parts[::-1]):
if len(part) == 0: # to prevent some part are empty after tokenization
continue
if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
continue
parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
to_trunc -= len(part)
if to_trunc <= 0:
break
truncated_example[key] = parts
return truncated_example
@staticmethod
def truncate_from_head(input_dict: Dict,
num_tokens_to_truncate: int=0) -> Dict:
r"""truncate the inputs from the head
"""
truncated_example = defaultdict(list)
shortenable_ids = input_dict['shortenable_ids']
for key in input_dict:
parts = input_dict[key]
to_trunc = num_tokens_to_truncate
for i, part in enumerate(parts):
if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
continue
parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
to_trunc -= len(part)
if to_trunc <= 0:
break
truncated_example[key] = parts
return truncated_example
@staticmethod
def concate_parts(input_dict: Dict) -> Dict:
for key in input_dict:
input_dict[key] = list(itertools.chain(*input_dict[key]))
return input_dict
# @staticmethod
# def padding(input_dict: Dict,
# max_len: int, pad_id_for_inputs: int=0, pad_id_for_others: int=0) -> None:
# for key, value in input_dict.items():
# if (len(input_dict[key]) > max_len):
# raise ValueError(f'''
# Truncated seq length of '{key}' still greater than max length '{max_len}.'
# One possible reason is that no enough shortenable parts in template. Try add {{"shortenable": "True"}} property.
# ''')
# if 'input' in key:
# input_dict[key].extend([pad_id_for_inputs]*(max_len-len(value)))
# else:
# input_dict[key].extend([pad_id_for_others]*(max_len-len(value)))
# return input_dict
# def add_special_tokens(self, encoder_inputs):
# # add special tokens
# for key in encoder_inputs:
# if key == "input_ids":
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# encoder_inputs[key] = self.tokenizer.build_inputs_with_special_tokens(
# encoder_inputs[key])
# return encoder_inputs
def truncate(self, encoder_inputs):
total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
num_specials = self.num_special_tokens_to_add
# print("num_specials", num_specials)
num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
self.total_passed_sentences+=1
if num_tokens_to_truncate>0:
self.num_truncated_sentences += 1
if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
num_tokens_to_truncate=num_tokens_to_truncate)
return encoder_inputs
@ -234,46 +221,23 @@ class AbstractTask(abc.ABC):
"superglue-boolq", "qqp", "qnli", "superglue-record", "sst2"] "superglue-boolq", "qqp", "qnli", "superglue-record", "sst2"]
large_data_without_all_splits = [] #["qqp", "qnli", "superglue-record", "sst2"] large_data_without_all_splits = [] #["qqp", "qnli", "superglue-record", "sst2"]
def __init__(self, config, data_args, tokenizer, predict_with_generate, seed=42, default_max_length=1): def __init__(self, config, data_args, seed=42, default_max_length=1):
self.config = config self.config = config
self.seed = seed self.seed = seed
self.data_args = data_args self.data_args = data_args
self.tokenizer = tokenizer # self.tokenizer = tokenizer
self.predict_with_generate = predict_with_generate # self.predict_with_generate = predict_with_generate
self.default_max_length = default_max_length self.default_max_length = default_max_length
self.truncate_method = getattr(data_args, "truncate_method", "balanced")
tid = getattr(config, "template_id", 0) # generation_paradigm = getattr(config, "generation_paradigm", True)
vid = getattr(config, "verbalizer_id", 0)
self.template = ManualTemplate(tokenizer=self.tokenizer, text = self.templates_text[tid])
self.verbalizer = ManualVerbalizer(tokenizer=self.tokenizer, classes = self.labels_list, label_words=self.verbalizers[vid])
# if self.predict_with_generate:
# self.reverse_verbalizer = {(int(x) for x in self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))): label for label in self.labels_list}
# else:
# self.reverse_verbalizer = {int(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))[0]): label for label in self.labels_list}
self.tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=self.data_args.max_source_length, tokenizer=self.tokenizer, truncate_method=self.truncate_method)
generation_paradigm = getattr(config, "generation_paradigm", True)
# self.prompt = PromptCollections[self.name](tid, vid, generation_paradigm) # self.prompt = PromptCollections[self.name](tid, vid, generation_paradigm)
self.max_target_length = self.get_max_target_length(self.default_max_length)
def get_max_target_length(self, default_max_length):
if self.predict_with_generate:
return max([len(label) for key, label in self.verbalizer.label_words_ids.items()])
else:
return default_max_length
def seq2seq_format(self, source, target, extra_fields={} # def get_max_target_length(self, default_max_length):
): # if self.predict_with_generate:
# return -1
return {'source': ' '.join(source), # else:
'target': ' '.join(target), # return default_max_length
'task': self.name,
'extra_fields': extra_fields
}
def check_n_obs(self, n_obs, total_size): def check_n_obs(self, n_obs, total_size):
if n_obs is not None and n_obs > total_size: if n_obs is not None and n_obs > total_size:
@ -312,37 +276,9 @@ class AbstractTask(abc.ABC):
else: else:
return indices[validation_size:] return indices[validation_size:]
def map_dataset(self, dataset):
# from IPython import embed; embed(header="in get target length")
return dataset.map(self.preprocessor).map(self.tokenizer_preprocessor)
def preprocessor(self, example): def preprocessor(self, example):
return example return example
def tokenizer_preprocessor(self, example):
# source, target = example
# from IPython import embed; embed(header="Trehre2")
label = example['label']
guid = example['idx']
meta = dict(example)
meta.pop("label")
meta.pop("idx")
# from IPython import embed; embed(header="Trehre2")
e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
template_e = self.template.wrap_one_example(e)
encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(template_e)
if self.predict_with_generate:
# return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
raise NotImplementedError
else:
return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
def get(self, split, n_obs=None, split_validation_test=False): def get(self, split, n_obs=None, split_validation_test=False):
# For small datasets (n_samples < 10K) without test set, we divide validation set to # For small datasets (n_samples < 10K) without test set, we divide validation set to
# half, use one half as test set and one half as validation set. # half, use one half as test set and one half as validation set.
@ -368,7 +304,7 @@ class AbstractTask(abc.ABC):
# shuffles the data and samples it. # shuffles the data and samples it.
if n_obs is not None: if n_obs is not None:
dataset = self.subsample(dataset, n_obs) dataset = self.subsample(dataset, n_obs)
return self.map_dataset(dataset) return dataset.map(self.preprocessor)
class Squad(AbstractTask): class Squad(AbstractTask):
name = "squad" name = "squad"
@ -387,25 +323,7 @@ class Squad(AbstractTask):
return self.seq2seq_format(source, target, add_prefix) return self.seq2seq_format(source, target, add_prefix)
class MRPC(AbstractTask): ##GLUE
name = "mrpc"
labels_list = ["0", "1"]
metric = [metrics.f1_score, metrics.accuracy]
metric_names = ["f1", "accuracy"]
split_to_data_split = {"train": "train",
"validation": "validation",
"test": "validation"}
def load_dataset(self, split):
return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")
# def preprocessor(self, example, add_prefix=True):
# src_texts = ["sentence1:", example['sentence1'],
# "sentence2:", example["sentence2"]]
# tgt_texts = [str(example['label'])]
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
class COLA(AbstractTask): class COLA(AbstractTask):
name = "cola" name = "cola"
labels_list = ["0", "1"] labels_list = ["0", "1"]
@ -415,15 +333,20 @@ class COLA(AbstractTask):
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
templates_text = {"0": """sentence: {"meta": 'sentence', "shortenable":True} Are there any error in the sentence? {"mask"}""",
}
verbalizers = {
"0":{ "0": "yes", "1": "no"}
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.cola")[split]
else:
return datasets.load_dataset('glue', 'cola', return datasets.load_dataset('glue', 'cola',
split=split, script_version="master") split=split, script_version="master")
# def preprocessor(self, example, add_prefix=True):
# src_texts = ["sentence:", example['sentence']]
# tgt_texts = [str(example['label'])]
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
class SST2(AbstractTask): class SST2(AbstractTask):
name = "sst2" name = "sst2"
@ -434,38 +357,50 @@ class SST2(AbstractTask):
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
verbalizers = [ verbalizers = {
"0":{"0":"negative","1":"positive"}
}
] templates_text = {
"0":"""The sentiment of sentence: "{"meta":"sentence", "shortenable":True} is {"mask"}."""
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.sst2")[split]
else:
return datasets.load_dataset('glue', 'sst2', return datasets.load_dataset('glue', 'sst2',
split=split, script_version="master") split=split, script_version="master")
def preprocessor(self, example, add_prefix=True):
src_texts = ["sentence:", example['sentence']]
tgt_texts = [str(example['label'])]
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
class STSB(AbstractTask): class MRPC(AbstractTask):
name = "stsb" name = "mrpc"
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)] labels_list = ["0", "1"]
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef] metric = [metrics.f1_score, metrics.accuracy]
metric_names = ["pearson", "spearmanr"] metric_names = ["f1", "accuracy"]
split_to_data_split = {"train": "train", split_to_data_split = {"train": "train",
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
def load_dataset(self, split):
return datasets.load_dataset('glue', 'stsb',
split=split, script_version="master")
def preprocessor(self, example, add_prefix=True): templates_text = {
src_texts = ["sentence1:", example['sentence1'], "0": """sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2: {"meta":"sentence2", "shortenable":True}. Are sentence1 and sentence2 equivalent? {"mask"}.""",
"sentence2:", example["sentence2"]] }
tgt_texts = [str(round_stsb_target(example['label']))]
return self.seq2seq_format(src_texts, tgt_texts, add_prefix) verbalizers = {
"0":{"0": "no","1": "yes"}
}
def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mrpc")[split]
else:
return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")
class QQP(AbstractTask): class QQP(AbstractTask):
@ -477,14 +412,46 @@ class QQP(AbstractTask):
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
templates_text = {"0":
"""question1: {"meta": 'question1', "shortenable":True}. question2: {"meta": 'question2', "shortenable":True} Are question1 and question2 equivalent? {"mask"}."""
}
verbalizers = {
"0":{"0": "no","1": "yes"}
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qqp")[split]
else:
return datasets.load_dataset('glue', 'qqp', return datasets.load_dataset('glue', 'qqp',
split=split, script_version="master") split=split, script_version="master")
class STSB(AbstractTask):
name = "stsb"
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)]
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
metric_names = ["pearson", "spearmanr"]
split_to_data_split = {"train": "train",
"validation": "validation",
"test": "validation"}
verbalizers = {
""
}
def load_dataset(self, split):
return datasets.load_dataset('glue', 'stsb',
split=split, script_version="master")
def preprocessor(self, example, add_prefix=True): def preprocessor(self, example, add_prefix=True):
src_texts = ["question1:", example['question1'], src_texts = ["sentence1:", example['sentence1'],
"question2:", example["question2"]] "sentence2:", example["sentence2"]]
tgt_texts = [str(example['label'])] tgt_texts = [str(round_stsb_target(example['label']))]
return self.seq2seq_format(src_texts, tgt_texts, add_prefix) return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
@ -498,14 +465,29 @@ class MNLI(AbstractTask):
metric_names = ["accuracy"] metric_names = ["accuracy"]
templates_text = {
"0":"""premise: {"meta": 'premise', "shortenable":True}. hypothesis: {"meta": 'hypothesis', "shortenable":True} Does the premise entails the hypothesis? {"mask"}.""",
}
verbalizers = {
"0":{
"0": "yes",
"1": "neutral",
"2": "no",
}
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mnli")[split]
else:
return datasets.load_dataset('glue', 'mnli', split=split, script_version="master") return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")
def preprocessor(self, example, add_prefix=True): # def preprocessor(self, example, add_prefix=True):
src_texts = ["premise:", example['premise'], # src_texts = ["premise:", example['premise'],
"hypothesis", example["hypothesis"]] # "hypothesis", example["hypothesis"]]
tgt_texts = [str(example['label'])] # tgt_texts = [str(example['label'])]
return self.seq2seq_format(src_texts, tgt_texts, add_prefix) # return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
class QNLI(AbstractTask): class QNLI(AbstractTask):
@ -517,14 +499,33 @@ class QNLI(AbstractTask):
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
templates_text = {
"0": """premise: {"meta": 'sentence', "shortenable":True}. hypothesis: {"meta": 'question', "shortenable":True}"""+
"""Does the premise entails the hypothesis? {"mask"}.""",
}
verbalizers = {
"0":{
"0": "yes",
"1": "no",
}
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qnli")[split]
else:
return datasets.load_dataset('glue', 'qnli', split=split, script_version="master") return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
def preprocessor(self, example, add_prefix=True): # def load_dataset(self, split):
src_texts = ["question:", example['question'], # return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
"sentence:", example["sentence"]]
tgt_texts = [str(example['label'])] # def preprocessor(self, example, add_prefix=True):
return self.seq2seq_format(src_texts, tgt_texts, add_prefix) # src_texts = ["question:", example['question'],
# "sentence:", example["sentence"]]
# tgt_texts = [str(example['label'])]
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
#Tested #Tested
class RTE(AbstractTask): class RTE(AbstractTask):
@ -537,15 +538,15 @@ class RTE(AbstractTask):
"test": "validation"} "test": "validation"}
templates_text = [ templates_text = {
"""sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2:,"""+ "0": """sentence1: {"meta": 'sentence1', "shortenable":True} sentence2: {"meta":"sentence2", "shortenable":True} The answer was {"mask"}.""",
"""{"meta":"sentence2", "shortenable":True}. The answer was {"mask"}.""", }
]
verbalizers = [{ verbalizers = {
"0": "yes", "0":{"0": "yes",
"1": "no" "1": "no"
}] }
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk: if self.data_args.datasets_load_from_disk:
@ -555,38 +556,6 @@ class RTE(AbstractTask):
split=split, script_version="master") split=split, script_version="master")
#Tested
class SuperGLUEBoolQ(AbstractTask):
name="superglue-boolq"
labels_list = ['0', '1']
metric = [metrics.accuracy]
metric_names = ["accuracy"]
split_to_data_split = {"train": "train",
"validation": "validation",
"test": "validation"}
verbalizers = [
{
"0": "no",
"1": "yes"
},
]
mlmhead_verbalizers = {
"0": "no",
"1": "yes"
}
templates_text = [
"""hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
]
def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
else:
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
class WNLI(AbstractTask): class WNLI(AbstractTask):
name = "wnli" name = "wnli"
@ -597,13 +566,13 @@ class WNLI(AbstractTask):
"validation": "validation", "validation": "validation",
"test": "validation"} "test": "validation"}
verbalizers = [{ verbalizers = {
"0": "True", "0":{"0": "True",
"1": "False", "1": "False",
}] }
templates_text = [ }
"""{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}.""" templates_text = {"0": """{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}."""
] }
def load_dataset(self, split): def load_dataset(self, split):
@ -613,6 +582,34 @@ class WNLI(AbstractTask):
return datasets.load_dataset('glue', 'wnli', split=split, script_version="master") return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")
#SuperGLUE
class SuperGLUEBoolQ(AbstractTask):
name="superglue-boolq"
labels_list = ['0', '1']
metric = [metrics.accuracy]
metric_names = ["accuracy"]
split_to_data_split = {"train": "train",
"validation": "validation",
"test": "validation"}
verbalizers = {
"0": {
"0": "no",
"1": "yes"
},
}
templates_text = {
"0": """hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
}
def load_dataset(self, split):
if self.data_args.datasets_load_from_disk:
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
else:
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
# #
class SuperGLUECB(AbstractTask): class SuperGLUECB(AbstractTask):
name = "superglue-cb" name = "superglue-cb"
@ -623,14 +620,15 @@ class SuperGLUECB(AbstractTask):
metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy] metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
metric_names = ["f1_multiclass", "accuracy"] metric_names = ["f1_multiclass", "accuracy"]
verbalizers = [{ verbalizers = {
"0": "yes", "0":{"0": "yes",
"1": "no", "1": "no",
"2": "maybe" "2": "maybe"
}] }
templates_text = [ }
"""hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}.""" templates_text = {
] "0": """hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}."""
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk: if self.data_args.datasets_load_from_disk:
@ -648,13 +646,15 @@ class SuperGLUECOPA(AbstractTask):
metric = [metrics.accuracy] metric = [metrics.accuracy]
metric_names = ["accuracy"] metric_names = ["accuracy"]
verbalizers = [{ verbalizers = {
"0":{
"0": "1", "0": "1",
"1": "2", "1": "2",
}] }
templates_text = [ }
"""choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}.""" templates_text = {
] "0": """choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}."""
}
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk: if self.data_args.datasets_load_from_disk:
@ -673,19 +673,16 @@ class SuperGLUEMultiRC(AbstractTask):
metrics.accuracy] metrics.accuracy]
metric_names = ["f1", "em"] metric_names = ["f1", "em"]
# generation_verbalizers = [{
# "0": "no",
# "1": "yes",
# },
# ]
verbalizers = [{ verbalizers = {
"0": {
"0": "no", "0": "no",
"1": "yes", "1": "yes",
}] }
templates_text = [ }
"""question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}.""" templates_text = {
] "0": """question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}."""
}
def load_dataset(self, split): def load_dataset(self, split):
@ -720,15 +717,16 @@ class SuperGLUEWIC(AbstractTask):
metric = [metrics.accuracy] metric = [metrics.accuracy]
metric_names = ["accuracy"] metric_names = ["accuracy"]
verbalizers = [{ verbalizers = {
"0": {
"0": "No", "0": "No",
"1": "Yes", "1": "Yes",
}] }
}
templates_text = [ templates_text = {
"""sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}. "0": """sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}."""
""" }
]
def load_dataset(self, split): def load_dataset(self, split):
if self.data_args.datasets_load_from_disk: if self.data_args.datasets_load_from_disk:
@ -786,7 +784,7 @@ class SuperGLUEWIC(AbstractTask):
# text = self._mark_span(text, example['span2_text'], span2_index, '#') # text = self._mark_span(text, example['span2_text'], span2_index, '#')
# src_texts = ["text:", text] # src_texts = ["text:", text]
# tgt_texts = [str(example["label"])] # tgt_texts = [str(example["label"])]
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix) # return self.fseq2seq_format(src_texts, tgt_texts, add_prefix)
class SuperGLUERecord(AbstractTask): class SuperGLUERecord(AbstractTask):
@ -875,9 +873,9 @@ TASK_MAPPING = OrderedDict(
class AutoTask: class AutoTask:
@classmethod @classmethod
def get(self, task, config, data_args, tokenizer,predict_with_generate, seed=42): def get(self, task, config, data_args, seed=42):
if task in TASK_MAPPING: if task in TASK_MAPPING:
return TASK_MAPPING[task](config, data_args, tokenizer,predict_with_generate, seed) return TASK_MAPPING[task](config, data_args, seed)
raise ValueError( raise ValueError(
"Unrecognized task {} for AutoTask Model: {}.\n" "Unrecognized task {} for AutoTask Model: {}.\n"
"Task name should be one of {}.".format( "Task name should be one of {}.".format(

View File

@ -34,6 +34,7 @@ from transformers import (
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoTokenizer, AutoTokenizer,
DataCollatorForSeq2Seq,
HfArgumentParser, HfArgumentParser,
MBartTokenizer, MBartTokenizer,
default_data_collator, default_data_collator,
@ -41,8 +42,8 @@ from transformers import (
) )
from transformers.trainer_utils import is_main_process, get_last_checkpoint from transformers.trainer_utils import is_main_process, get_last_checkpoint
# from ..seq2seq.utils import get_adapter_config # from ..seq2seq.utils import get_adapter_config
from examples_prompt.data_processors import AutoTask, TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator from examples_prompt.data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
from examples_prompt.seq2seq_trainer import Seq2SeqTrainer from transformers import Seq2SeqTrainer
# from training_args import AdapterTrainingArguments # from training_args import AdapterTrainingArguments
from examples_prompt.trainers.trainer_utils import save_training_config from examples_prompt.trainers.trainer_utils import save_training_config
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -56,7 +57,8 @@ import json
import numpy as np import numpy as np
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
TASK_TO_METRICS = {"mrpc": ["accuracy", "f1"], TASK_TO_METRICS = {"mrpc": ["accuracy", "f1"],
"cola": ['matthews_correlation'], "cola": ['matthews_correlation'],
@ -109,7 +111,6 @@ class RemainArgHfArgumentParser(HfArgumentParser):
def main(): def main():
# See all possible arguments in src/transformers/training_args.py # See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script. # or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns. # We now keep distinct sets of args, for a cleaner separation of concerns.
@ -202,6 +203,16 @@ def main():
use_auth_token=True if model_args.use_auth_token else None, use_auth_token=True if model_args.use_auth_token else None,
) )
if training_args.predict_with_generate:
model = AutoModelForSeq2SeqLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
)
else:
model = AutoModelForMaskedLM.from_pretrained( model = AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path, model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path), from_tf=bool(".ckpt" in model_args.model_name_or_path),
@ -212,11 +223,6 @@ def main():
) )
model.resize_token_embeddings(len(tokenizer)) model.resize_token_embeddings(len(tokenizer))
from openprompt.prompts import ManualTemplate
from openprompt.plms import MLMTokenizerWrapper
template = ManualTemplate(tokenizer, text="""sentence1: {"meta": 'premise'}, sentence2:,"""+
"""{"meta":"hypothesis", "shortenable":True}, The answer was {"mask"} .""")
tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method='tail')
@ -248,49 +254,185 @@ def main():
# Temporarily set max_target_length for training. # Temporarily set max_target_length for training.
#max_target_length = data_args.max_target_length #max_target_length = data_args.max_target_length
padding = "max_length" if data_args.pad_to_max_length else False
def preprocess_function(examples, **kwargs):
# max_target_length += 1
tokenizer = kwargs['tokenizer']
data_args = kwargs['data_args']
print("max_length", data_args.max_source_length)
model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
padding="max_length", truncation=True)
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
# model_inputs["mask_position"] = mask_position
model_inputs["extra_fields"] = examples['extra_fields']
# from IPython import embed; embed(header="Therer")
return model_inputs
column_names = ['source', 'target', 'label', 'extra_fields'] column_names = ['source', 'target', 'label', 'extra_fields']
performance_metrics = {} performance_metrics = {}
def get_prompts(task, tokenizer, predict_with_generate, template_id="0", verbalizer_id="0"):
# tid = getattr(config, "template_id", "0")
# vid = getattr(config, "verbalizer_id", "0")
from openpromptu.prompts import GenerationVerbalizer, ManualVerbalizer
from openpromptu.prompts import ManualTemplate
template = ManualTemplate(text = task.templates_text[template_id])
if predict_with_generate:
verbalizer = GenerationVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
else:
verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
# max_target_length = self.get_max_target_length(self.default_max_length)
from openpromptu import TokenizerWrapper
tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
return template, verbalizer, tokenizer_wrapper
from openpromptu.data_utils import InputExample
max_target_length = 32
if os.path.basename(model_args.model_name_or_path).startswith("t5"):
mask_token_func = lambda i: tokenizer.additional_special_tokens[i]
def preprocess_function(raw_example, **kwargs):
# max_target_length += 1
tokenizer = kwargs['tokenizer']
data_args = kwargs['data_args']
template = kwargs['template']
verbalizer = kwargs['verbalizer']
tokenizer_wrapper = kwargs['tokenizer_wrapper']
split = kwargs['split']
# extra_fileds = example['extra_fields']
example = InputExample(**raw_example)
# from collections import namedtuple
# example['tgt_text'] = ""
# example = namedtuple("ObjectName", example.keys())(*example.values())
try:
example = verbalizer.wrap_one_example(example)
example, other = template.wrap_one_example(example)
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
model_inputs = tokenizer(input_sentence, max_length=256,
padding="max_length", truncation=True)
except:
from IPython import embed; embed(header="Therer")
# if split == "train":
with tokenizer.as_target_tokenizer():
label = tokenizer(other['tgt_text']).input_ids
# label = [l if l != tokenizer.pad_token_id else -100 for l in label]
# from IPython import embed; embed(header="Therer")
model_inputs["labels"] = label
# else:
# # from IPython import embed; embed(header="Therer")
# model_inputs["tgt_text"] = other['tgt_text']
# model_inputs['labels'] = None # model_inputs["extra_fields"] = extra_fileds
# from IPython import embed; embed(header="Therer2")
return model_inputs
def compute_metrics(eval_preds, tokenizer, dataset_name, eval_metric):
# from IPython import embed; embed(header="In compute metrics")
preds, labels = eval_preds
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
# post_processor = .get(data_args.dataset_name[0], tokenizer,
# data_args.ignore_pad_token_for_loss)
# decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
result = {}
for metric in eval_metric:
result.update(metric(decoded_preds, decoded_labels))
average_metric = sum(result.values())/len(result)
result.update({"average_metrics":average_metric})
return result
elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
or os.path.basename(model_args.model_name_or_path).startswith("bert"):
mask_token_func = lambda i: tokenizer.mask_token
def preprocess_function(raw_example, **kwargs):
# max_target_length += 1
# from IPython import embed; embed(header="Therer")
tokenizer = kwargs['tokenizer']
data_args = kwargs['data_args']
template = kwargs['template']
verbalizer = kwargs['verbalizer']
tokenizer_wrapper = kwargs['tokenizer_wrapper']
example = InputExample(**raw_example)
# from collections import namedtuple
# example['tgt_text'] = ""
# example = namedtuple("ObjectName", example.keys())(*example.values())
# try:
# example = verbalizer.wrap_one_example(example)
example, other = template.wrap_one_example(example)
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
model_inputs = tokenizer(input_sentence, max_length=256,
padding="max_length", truncation=True)
# print("max_length", data_args.max_source_length)
# model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
# padding="max_length", truncation=True)
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
# model_inputs["mask_position"] = mask_position
# model_inputs["extra_fields"] = examples['extra_fields']
# from IPython import embed; embed(header="Therer")
return model_inputs
def compute_metrics(eval_preds, dataset_name):
# from IPython import embed; embed(header="In compute metrics")
preds, labels = eval_preds.predictions, eval_preds.label_ids
preds = np.argmax(preds, axis=-1)
result = {}
average_metrics = []
for metric in eval_metric:
metric_item = metric(preds, labels)
metric_value = list(metric_item.values())
result.update(metric_item)
average_metrics.extend(metric_value)
print("average:",average_metrics)
average_metric = sum(average_metrics)/len(average_metrics)
result.update({"average_metrics":average_metric})
return result
if training_args.do_train: if training_args.do_train:
train_task = AutoTask.get(data_args.task_name, train_task = AutoTask.get(data_args.task_name,
data_args.dataset_config_name, data_args.dataset_config_name,
data_args=data_args, data_args=data_args,
tokenizer=tokenizer, # tokenizer=tokenizer,
predict_with_generate=training_args.predict_with_generate, # predict_with_generate=training_args.predict_with_generate,
seed=data_args.data_seed) seed=data_args.data_seed)
train_dataset = train_task.get(split='train', train_dataset = train_task.get(split='train',
split_validation_test=training_args.split_validation_test, split_validation_test=training_args.split_validation_test,
n_obs=data_args.max_train_samples) n_obs=data_args.max_train_samples)
template, verbalizer, tokenizer_wrapper = get_prompts(train_task, tokenizer, training_args.predict_with_generate)
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
functools.partial(preprocess_function, functools.partial(preprocess_function,
data_args=data_args, data_args=data_args,
tokenizer=tokenizer), tokenizer=tokenizer,
batched=True, template=template,
verbalizer=verbalizer,
tokenizer_wrapper=tokenizer_wrapper,
split="train"),
batched=False,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
remove_columns=[x for x in train_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"], remove_columns=[x for x in train_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
load_from_cache_file=not data_args.overwrite_cache, load_from_cache_file=not data_args.overwrite_cache,
@ -298,6 +440,8 @@ def main():
eval_splits_names = [] eval_splits_names = []
if training_args.do_eval: if training_args.do_eval:
@ -306,87 +450,49 @@ def main():
eval_splits_names.append("test") eval_splits_names.append("test")
eval_splits = {} eval_splits = {}
for split_name in eval_splits_names: for split_name in eval_splits_names:
_tasks = {dataset_name: AutoTask.get(dataset_name, eval_task = AutoTask.get(data_args.task_name,
dataset_config_name, data_args.dataset_config_name,
data_args=data_args, data_args=data_args,
tokenizer=tokenizer, # tokenizer=tokenizer,
predict_with_generate=training_args.predict_with_generate, # predict_with_generate=training_args.predict_with_generate,
seed=data_args.data_seed) seed=data_args.data_seed)
for dataset_name, dataset_config_name\ # for dataset_name, dataset_config_name\
in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))} # in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}
_datasets = {dataset_name: task.get(split=split_name, eval_dataset = eval_task.get(split=split_name,
split_validation_test=training_args.split_validation_test, split_validation_test=training_args.split_validation_test,
n_obs=data_args.max_train_samples) n_obs=data_args.max_train_samples)
for dataset_name, task in _tasks.items()
}
_datasets = {dataset_name: d.map(
template, _verbalizer, tokenizer_wrapper = get_prompts(eval_task, tokenizer, training_args.predict_with_generate)
eval_dataset = eval_dataset.map(
functools.partial(preprocess_function, functools.partial(preprocess_function,
data_args=data_args, data_args=data_args,
tokenizer=tokenizer), tokenizer=tokenizer,
batched=True, template=template,
verbalizer=_verbalizer,
tokenizer_wrapper=tokenizer_wrapper,
split=split_name),
batched=False,
num_proc=data_args.preprocessing_num_workers, num_proc=data_args.preprocessing_num_workers,
remove_columns=[x for x in d.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"], remove_columns=[x for x in eval_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
load_from_cache_file=not data_args.overwrite_cache, load_from_cache_file=not data_args.overwrite_cache,
) )
for dataset_name, d in _datasets.items()
}
eval_splits[split_name] = _datasets
eval_splits[split_name] = eval_dataset
if split_name == "test": if split_name == "test":
eval_metrics = {dataset_name:task.metric for dataset_name, task in _tasks.items()} eval_metric = eval_task.metric
verbalizers = {dataset_name:task.verbalizer for dataset_name, task in _tasks.items()} verbalizer = _verbalizer
# Metric, we assume we have only one training task.
# eval_metrics = [task.metric for task in
# for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)][0]
# Extracts the extra information needed to evaluate on each dataset.
# These information are only used in the compute_metrics.
# We will assume that the test/eval dataloader does not change the order of
# the data.
# data_info = {"eval": eval_datasets[data_args.eval_dataset_name[0]]['extra_fields'],
# "test": test_datasets[data_args.test_dataset_name[0]]['extra_fields'],
# "train": train_dataset['extra_fields']}
def compute_metrics(eval_preds, dataset_name):
preds, labels = eval_preds.predictions, eval_preds.label_ids
preds = np.argmax(preds, axis=-1)
result = {}
average_metrics = []
for metric in eval_metrics[dataset_name]:
metric_item = metric(preds, labels)
metric_value = list(metric_item.values())
result.update(metric_item)
average_metrics.extend(metric_value)
print("average:",average_metrics)
average_metric = sum(average_metrics)/len(average_metrics)
result.update({"average_metrics":average_metric})
return result
# from IPython import embed; embed(header="isseq2seq")
# Initialize our Trainer
# if training_args.is_seq2seq == True:
# trainer = Seq2SeqTrainer(
# model=model,
# args=training_args,
# delta_args=delta_args,
# train_dataset=splits['train'] if training_args.do_train else None,
# eval_dataset=list(splits['validation'].values())[0] if training_args.do_eval else None,
# data_info = data_info,
# tokenizer=tokenizer,
# data_collator=data_collator,
# compute_metrics=compute_metrics if training_args.predict_with_generate else None,
# evaluation_metrics = TASK_TO_METRICS[data_args.dataset_name[0]],
# )
# else:
class MLMTrainer(Trainer): class MLMTrainer(Trainer):
_verbalizers = verbalizers def __init__(self, verbalizer=None, **kwargs):
super().__init__(**kwargs)
self.verbalizer=verbalizer
# def training_step(self, model, inputs): # def training_step(self, model, inputs):
# from IPython import embed; embed(header="in trainstep") # from IPython import embed; embed(header="in trainstep")
@ -394,7 +500,7 @@ def main():
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop('labels') labels = inputs.pop('labels')
extra_fields = inputs.pop("extra_fields") # extra_fields = inputs.pop("extra_fields")
outputs = model(**inputs) outputs = model(**inputs)
logits = outputs.get("logits") logits = outputs.get("logits")
input_ids = inputs['input_ids'] input_ids = inputs['input_ids']
@ -402,20 +508,8 @@ def main():
# from IPython import embed; embed(header="382") # from IPython import embed; embed(header="382")
verbalizer = self._verbalizers[extra_fields[0]['dataset_name']].cuda() verbalizer = self.verbalizer.cuda()
logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)] logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
# colidx = torch.where(input_ids == verbalizer.tokenizer.mask_token_id)[0].cpu()
# print(colidx)
# missing = set([i for i in range(input_ids.size(0))]) - set(colidx.numpy())
# print(missing)
# if len(missing) > 0:
# print("missing")
# missing = list(missing)[0]
# input_ids_missing = input_ids[missing]
# print(input_ids_missing)
# missing_tokens = verbalizer.tokenizer.convert_ids_to_tokens(input_ids_missing)
# print(missing_tokens)
label_logits = verbalizer.process_logits(logits_at_mask) label_logits = verbalizer.process_logits(logits_at_mask)
loss_fct = torch.nn.CrossEntropyLoss() loss_fct = torch.nn.CrossEntropyLoss()
# from IPython import embed; embed(header="In compute loss") # from IPython import embed; embed(header="In compute loss")
@ -423,21 +517,136 @@ def main():
outputs.logits = label_logits outputs.logits = label_logits
return (loss, outputs) if return_outputs else loss return (loss, outputs) if return_outputs else loss
class MySeq2SeqTrainer(Seq2SeqTrainer):
def compute_loss(self, model, inputs, return_outputs=False):
# from IPython import embed; embed(header="agag")
intlabel = inputs.pop('label')
# extra_fields = inputs.pop("extra_fields")
outputs = model(**inputs)
# logits = outputs.get("logits")
# input_ids = inputs['input_ids']
# # from IPython import embed; embed(header="382")
# verbalizer = self._verbalizers.cuda()
# logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
# label_logits = verbalizer.process_logits(logits_at_mask)
# loss_fct = torch.nn.CrossEntropyLoss()
# # from IPython import embed; embed(header="In compute loss")
# loss = loss_fct(label_logits, labels)
# outputs.logits = label_logits
if return_outputs:
return (outputs.loss, outputs)
else:
return outputs.loss
# def evaluate(
# self,
# eval_dataset: Optional[Dict[str, Dataset]] = None,
# ignore_keys: Optional[List[str]] = None,
# metric_key_prefix: str = "eval",
# max_length: Optional[int] = None,
# num_beams: Optional[int] = None,
# ) -> Dict[str, float]:
# # TODO: this also needs to be set per dataset
# self._max_length = max_length
# self._num_beams = num_beams
# return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step(
self,
model, #nn.Module,
inputs, #Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only, #: bool,
ignore_keys, #: Optional[List[str]] = None,
): #-> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
Subclass and override to inject custom behavior.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
labels (each being optional).
"""
if not self.args.predict_with_generate or prediction_loss_only:
return super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
intlabel = inputs.pop('label')
gen_kwargs = {
"max_length": 10, # self._max_length if s is not None else self.model.config.max_length,
"num_beams": 1 #self._num_beams if self._num_beams is not None else self.model.config.num_beams,
}
generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,
)
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
else:
loss = None
if self.args.prediction_loss_only:
return (loss, None, None)
labels = inputs["labels"]
if labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
# from IPython import embed; embed(header="In seqseqtrainer")
return (loss, generated_tokens, labels)
# def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys): # def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
# aa = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # aa = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
# # from IPython import embed; embed() # # from IPython import embed; embed()
# return aa # return aa
from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin # from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
class DataCollatorWithExtraFields(DataCollatorMixin): # class DataCollatorWithExtraFields(DataCollatorMixin):
return_tensors: str = "pt" # return_tensors: str = "pt"
def torch_call(self, features): # def torch_call(self, features):
print(len(features)) # # print(len(features))
extra_fields = [f.pop('extra_fields') for f in features] # # extra_fields = [f.pop('extra_fields') for f in features]
batch = torch_default_data_collator(features) # batch = torch_default_data_collator(features)
batch['extra_fields'] =extra_fields # batch['extra_fields'] =extra_fields
print(batch['input_ids'].size()) # # print(batch['input_ids'].size())
print(batch['labels'].size()) # # print(batch['labels'].size())
return batch # return batch
# from transformers.data.data_collator import DefaultDataCollator # from transformers.data.data_collator import DefaultDataCollator
@ -455,15 +664,29 @@ def main():
training_args.remove_unused_columns = False training_args.remove_unused_columns = False
if os.path.basename(model_args.model_name_or_path).startswith("roberta") or \
os.path.basename(model_args.model_name_or_path).startswith("bert"):
trainer = MLMTrainer( trainer = MLMTrainer(
model=model, model=model,
args=training_args, args=training_args,
train_dataset=train_dataset if training_args.do_train else None, train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_splits['eval'][data_args.task_name] if training_args.do_eval else None, eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
compute_metrics=functools.partial(compute_metrics, dataset_name=data_args.task_name), compute_metrics=functools.partial(compute_metrics, dataset_name=data_args.task_name),
# tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=DataCollatorWithExtraFields(), # data_collator=DataCollatorWithExtraFields(),
verbalizer=verbalizer,
) )
elif os.path.basename(model_args.model_name_or_path).startswith("t5"):
trainer = MySeq2SeqTrainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer, dataset_name=data_args.task_name, eval_metric=eval_metric),
tokenizer=tokenizer,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
)
# Saves training config. # Saves training config.
@ -522,24 +745,23 @@ def main():
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
for dataset_name, eval_dataset in eval_splits['eval'].items():
metrics = trainer.evaluate(eval_dataset=eval_dataset, metrics = trainer.evaluate(eval_dataset=eval_splits['eval'],
) )
trainer.log_metrics(f"{dataset_name}_eval", metrics) trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
trainer.save_metrics(f"{dataset_name}_eval", metrics) trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
all_results['evaluate'][dataset_name] = metrics all_results['evaluate'][data_args.task_name] = metrics
# Test # Test
all_results['test'] = {} all_results['test'] = {}
if training_args.do_test: if training_args.do_test:
logger.info("*** Test ***") logger.info("*** Test ***")
for dataset_name, test_dataset in eval_splits['test'].items(): metrics = trainer.evaluate(eval_dataset=eval_splits['test'],
metrics = trainer.evaluate(eval_dataset=test_dataset,
metric_key_prefix="test" metric_key_prefix="test"
) )
trainer.log_metrics(f"{dataset_name}_test", metrics) trainer.log_metrics(f"{data_args.task_name}_test", metrics)
trainer.save_metrics(f"{dataset_name}_test", metrics) trainer.save_metrics(f"{data_args.task_name}_test", metrics)
all_results['test'][dataset_name] = metrics all_results['test'][data_args.task_name] = metrics
# repo_name = create_hub_repo_name(root="DeltaHub", # repo_name = create_hub_repo_name(root="DeltaHub",
# dataset=data_args.task_name, # dataset=data_args.task_name,

View File

@ -1,7 +1,7 @@
cd configs/
python config_gen.py --job $3 python configs/config_gen.py --job $3
echo "Regenerate config" echo "Regenerate config"
cd ../
files=(cola mnli mrpc qnli qqp rte sst2 stsb superglue-boolq superglue-cb superglue-copa superglue-multirc superglue-record superglue-wic superglue-wsc.fixed) files=(cola mnli mrpc qnli qqp rte sst2 stsb superglue-boolq superglue-cb superglue-copa superglue-multirc superglue-record superglue-wic superglue-wsc.fixed)
for ((i=$1; i<=$2; i++)) for ((i=$1; i<=$2; i++))
do do

View File

@ -1,5 +1,5 @@
__version__ = "0.0.1" __version__ = "0.0.4"
class GlobalSetting: class GlobalSetting:
def __init__(self): def __init__(self):

View File

@ -728,6 +728,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
elif _delta_info['method'] == "insert_sequential": elif _delta_info['method'] == "insert_sequential":
self.insert_sequential_module(module=submodule, self.insert_sequential_module(module=submodule,
_delta_info=_delta_info) _delta_info=_delta_info)
elif _delta_info['method'] == "insert_parallel":
self.insert_parallel_module(module=submodule,
_delta_info=_delta_info)
else: else:
raise NotImplementedError raise NotImplementedError
@ -765,7 +768,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
submodule.forward = submodule.forward.__wrapped__ submodule.forward = submodule.forward.__wrapped__
delattr(submodule, _delta_info["delta_name"]) delattr(submodule, _delta_info["delta_name"])
else: else:
raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It'ss not a wrapped function.".format(name)) raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
elif _delta_info['method'] == "insert_parallel":
if hasattr(submodule.forward, "__wrapped__"):
submodule.forward = submodule.forward.__wrapped__
delattr(submodule, _delta_info["delta_name"])
else:
raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
else: else:
raise NotImplementedError raise NotImplementedError
@ -776,4 +785,3 @@ class DeltaBase(nn.Module, SaveLoadMixin):
except AttributeError: except AttributeError:
pass pass

View File

@ -28,7 +28,6 @@ class LowRankLinear(nn.Module):
self.r = r self.r = r
self.lora_alpha = lora_alpha self.lora_alpha = lora_alpha
self.lora_dropout = lora_dropout self.lora_dropout = lora_dropout
self.lin = nn.Linear(in_features, out_features) #
if lora_dropout > 0.: if lora_dropout > 0.:
self.lora_dropout = nn.Dropout(p=lora_dropout) self.lora_dropout = nn.Dropout(p=lora_dropout)
else: else:
@ -37,7 +36,6 @@ class LowRankLinear(nn.Module):
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features))) self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r))) self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r self.scaling = self.lora_alpha / self.r
self.lin.reset_parameters() #
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B) nn.init.zeros_(self.lora_B)

View File

@ -1,3 +1,4 @@
from examples_prompt.metrics.metrics import exact_match
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
from opendelta.utils.name_based_addressing import * from opendelta.utils.name_based_addressing import *
from opendelta.utils.cuda import get_device from opendelta.utils.cuda import get_device
@ -47,6 +48,7 @@ class SoftPromptLayer(nn.Module):
soft_token_num: int = 100, soft_token_num: int = 100,
raw_embedding: Optional[torch.Tensor] = None, raw_embedding: Optional[torch.Tensor] = None,
init_range: Optional[float] = 0.5, init_range: Optional[float] = 0.5,
other_expand_ids: Optional[Dict] = {"attention_mask":1, "token_type_ids":0},
token_init = False, token_init = False,
pad_id = 0, pad_id = 0,
device: Optional[str]=None, device: Optional[str]=None,
@ -59,10 +61,13 @@ class SoftPromptLayer(nn.Module):
self.pad_id = pad_id self.pad_id = pad_id
self.token_init = token_init self.token_init = token_init
self.device = device self.device = device
self.other_expand_ids = other_expand_ids
assert self.num_tokens>0 assert self.num_tokens>0
self.instantiate(raw_embedding(torch.tensor([0])).shape[-1]) self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
self.all_pseudo_tokens = {}
def pre_forward(self, *args, **kwargs): def pre_forward(self, *args, **kwargs):
# if attention_mask is passed as PLM's input, modify it here # if attention_mask is passed as PLM's input, modify it here
if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None: if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None:
@ -100,8 +105,19 @@ class SoftPromptLayer(nn.Module):
inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1) inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
kwargs['inputs_embeds'] = inputs_embeds kwargs['inputs_embeds'] = inputs_embeds
am = kwargs['attention_mask'] for expand_key in self.other_expand_ids:
am.data = torch.cat([torch.ones((*am.shape[:-1], inputs_embeds.shape[-2]-am.shape[-1]), dtype = am.dtype,device=am.device), am], dim=-1) if expand_key in kwargs:
real_tokens = kwargs[expand_key]
if expand_key in self.all_pseudo_tokens:
pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
else:
pseudo_tokens_value = self.other_expand_ids[expand_key]
pseudo_tokens = torch.ones(
(*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
dtype = real_tokens.dtype,
device=real_tokens.device) * pseudo_tokens_value
self.all_pseudo_tokens[expand_key] = pseudo_tokens
real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
return args, kwargs return args, kwargs
@ -136,6 +152,10 @@ class SoftPromptModel(DeltaBase):
soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input. soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input.
init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution. init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution.
token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the plm token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the plm
other_expand_ids (:obj:`dict`, *optional*, default to `{"attention_mask":1, "token_type_ids":0}`) The name of
other tokens and its default value that expand along with the input sequence. For example, when
you prepend 100 tokens to the input_ids, the attention_mask should be extended, and the token_type_ids should
be extended as well.
modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
the implemented ones) the implemented ones)
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
@ -151,6 +171,7 @@ class SoftPromptModel(DeltaBase):
soft_token_num=100, soft_token_num=100,
init_range = 0.5, init_range = 0.5,
token_init=True, token_init=True,
other_expand_ids={"attention_mask":1, "token_type_ids":0},
modified_modules: Optional[List[str]] = None, modified_modules: Optional[List[str]] = None,
exclude_modules: Optional[List[str]] = None, exclude_modules: Optional[List[str]] = None,
unfrozen_modules: Optional[List[str]] = None, unfrozen_modules: Optional[List[str]] = None,
@ -202,6 +223,7 @@ class SoftPromptModel(DeltaBase):
soft_prompt_layer = SoftPromptLayer( soft_prompt_layer = SoftPromptLayer(
soft_token_num = self.soft_token_num, soft_token_num = self.soft_token_num,
raw_embedding = self.raw_embedding, raw_embedding = self.raw_embedding,
other_expand_ids = self.other_expand_ids,
token_init = self.token_init, token_init = self.token_init,
init_range = self.init_range, init_range = self.init_range,
device = module_device, device = module_device,

View File

@ -17,7 +17,7 @@ print(requires)
with open('README.md', 'r') as f: with open('README.md', 'r') as f:
setuptools.setup( setuptools.setup(
name = 'opendelta', name = 'opendelta',
version = '0.0.3', version = '0.0.4',
description = "An open source framework for delta learning (parameter efficient learning).", description = "An open source framework for delta learning (parameter efficient learning).",
long_description=open("README.md", "r", encoding="utf-8").read(), long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",