example_prompt_unify

This commit is contained in:
parent 151e3dd2c2
commit 0de3cbd31d

Binary file not shown.
Binary file not shown.
@@ -1,3 +1,3 @@
 from .tasks import TASK_MAPPING, AutoTask
-from .data_collator import TaskDataCollatorForSeq2Seq
+# from .data_collator import TaskDataCollatorForSeq2Seq
-from .postprocessors import AutoPostProcessor
+# from .postprocessors import AutoPostProcessor
@@ -1,28 +1,28 @@
-import numpy as np
+# import numpy as np
-from dataclasses import dataclass
+# from dataclasses import dataclass
-from transformers import DataCollatorForSeq2Seq
+# from transformers import DataCollatorForSeq2Seq


-@dataclass
+# @dataclass
-class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+# class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-def check_uniqueness(self, samples):
-assert len(np.unique(samples)) == 1

-def __call__(self, features):
-# tasks = [d.pop('task') for d in features]
-# self.check_uniqueness(tasks)
-output = super().__call__(features)
-# output["task"] = tasks[0]
-return output

-# class CustomDataCollator(DefaultDataCollator):
 # def check_uniqueness(self, samples):
 # assert len(np.unique(samples)) == 1

 # def __call__(self, features):
-# mask_positions = [d.pop('mask_positions') for d in features]
+# # tasks = [d.pop('task') for d in features]
 # # self.check_uniqueness(tasks)
 # output = super().__call__(features)

 # # output["task"] = tasks[0]
 # return output

+# # class CustomDataCollator(DefaultDataCollator):
+# # def check_uniqueness(self, samples):
+# # assert len(np.unique(samples)) == 1

+# # def __call__(self, features):
+# # mask_positions = [d.pop('mask_positions') for d in features]
+# # # self.check_uniqueness(tasks)
+# # output = super().__call__(features)

+# # # output["task"] = tasks[0]
+# # return output
@@ -15,6 +15,7 @@ import re
 from openprompt.prompts import ManualTemplate, ManualVerbalizer
 from openprompt.plms.utils import TokenizerWrapper
 from openprompt.data_utils import InputExample
+from openprompt.prompts import GenerationVerbalizer
 import itertools

@@ -28,184 +29,170 @@ from typing import List, Dict
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from openprompt.utils import round_list
|
from openprompt.utils import round_list
|
||||||
import warnings
|
import warnings
|
||||||
class MLMTokenizerWrapper:
|
|
||||||
def __init__(self, max_seq_length, tokenizer, truncate_method):
|
# class MLMTokenizerWrapper:
|
||||||
self.max_seq_length=max_seq_length
|
# def __init__(self, max_seq_length, tokenizer, truncate_method, mask_token_func=lambda i: "<mask>"):
|
||||||
self.tokenizer=tokenizer
|
# self.max_seq_length=max_seq_length
|
||||||
self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
|
# self.tokenizer=tokenizer
|
||||||
# from IPython import embed; embed(header="Truega")
|
# self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
|
||||||
self.truncate_method=truncate_method
|
# # from IPython import embed; embed(header="Truega")
|
||||||
self.total_passed_sentences = 0
|
# self.truncate_method=truncate_method
|
||||||
self.num_truncated_sentences = 0
|
# self.total_passed_sentences = 0
|
||||||
if truncate_method=='tail':
|
# self.num_truncated_sentences = 0
|
||||||
self.truncate_fct = self.truncate_from_tail
|
# self.mask_token_func = mask_token_func
|
||||||
elif truncate_method=='head':
|
|
||||||
self.truncate_fct = self.truncate_from_head
|
# if truncate_method=='tail':
|
||||||
elif truncate_method == 'balanced':
|
# self.truncate_fct = self.truncate_from_tail
|
||||||
self.truncate_fct = self.balanced_truncate
|
# elif truncate_method=='head':
|
||||||
else:
|
# self.truncate_fct = self.truncate_from_head
|
||||||
raise NotImplementedError
|
# elif truncate_method == 'balanced':
|
||||||
|
# self.truncate_fct = self.balanced_truncate
|
||||||
|
# else:
|
||||||
|
# raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# def merge_wrapped_example(self, wrapped_example,):
|
||||||
|
# ''' # TODO doens't consider the situation that input has two parts
|
||||||
|
# '''
|
||||||
|
|
||||||
|
# wrapped_example
|
||||||
|
|
||||||
|
# # for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
|
||||||
|
# # the input. Or in generation tasks, we need to generate a piece of target_text.
|
||||||
|
# # In these case, it tokenized to the encoded_tgt_text for furture use.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# encoder_inputs = defaultdict(list)
|
||||||
|
# # from IPython import embed; embed(header="Line 67")
|
||||||
|
|
||||||
|
# mask_count = 0
|
||||||
|
# for piece in wrapped_example:
|
||||||
|
# if piece['text'] == "<mask>":
|
||||||
|
# encode_text = self.tokenizer.encode(self.mask_token_func(mask_count), add_special_tokens=False, return_special_tokens_mask=True )
|
||||||
|
# mask_count += 1
|
||||||
|
# else:
|
||||||
|
# encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
|
||||||
|
# encoder_inputs['input_ids'].append(encode_text)
|
||||||
|
# encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
|
||||||
|
|
||||||
|
|
||||||
def merge_wrapped_example(self, wrapped_example, ):
|
# encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
|
||||||
''' # TODO doens't consider the situation that input has two parts
|
# encoder_inputs.pop("shortenable_ids")
|
||||||
'''
|
# encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
||||||
|
# decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
|
||||||
|
|
||||||
wrapped_example, others = wrapped_example
|
# return decoded_inputs
|
||||||
|
|
||||||
# for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
|
|
||||||
# the input. Or in generation tasks, we need to generate a piece of target_text.
|
# @staticmethod
|
||||||
# In these case, it tokenized to the encoded_tgt_text for furture use.
|
# def balanced_truncate(input_dict: Dict,
|
||||||
|
# num_tokens_to_truncate: int=0) -> Dict:
|
||||||
|
# '''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
|
||||||
|
# '''
|
||||||
|
# shortenable_lens = [len(parts) if parts[0]==1 else 0
|
||||||
|
# for parts in input_dict['shortenable_ids']]
|
||||||
|
# total_shortenable_len = sum(shortenable_lens)
|
||||||
|
# num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
|
||||||
|
# for part_len in shortenable_lens]
|
||||||
|
# round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
|
||||||
|
|
||||||
|
# truncated_example = defaultdict(list)
|
||||||
|
# for key in input_dict:
|
||||||
|
# parts = input_dict[key]
|
||||||
|
# for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
|
||||||
|
# truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
|
||||||
|
# return truncated_example
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def truncate_from_tail(input_dict: Dict,
|
||||||
|
# num_tokens_to_truncate: int=0) -> Dict:
|
||||||
|
# r"""truncate the inputs from the rear
|
||||||
|
# """
|
||||||
|
# truncated_example = defaultdict(list)
|
||||||
|
# shortenable_ids = input_dict['shortenable_ids']
|
||||||
|
|
||||||
|
# for key in input_dict:
|
||||||
|
# parts = input_dict[key]
|
||||||
|
# to_trunc = num_tokens_to_truncate
|
||||||
|
# for i, part in enumerate(parts[::-1]):
|
||||||
|
# if len(part) == 0: # to prevent some part are empty after tokenization
|
||||||
|
# continue
|
||||||
|
# if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
|
||||||
|
# continue
|
||||||
|
# parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||||
|
# to_trunc -= len(part)
|
||||||
|
# if to_trunc <= 0:
|
||||||
|
# break
|
||||||
|
# truncated_example[key] = parts
|
||||||
|
# return truncated_example
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def truncate_from_head(input_dict: Dict,
|
||||||
|
# num_tokens_to_truncate: int=0) -> Dict:
|
||||||
|
# r"""truncate the inputs from the head
|
||||||
|
# """
|
||||||
|
# truncated_example = defaultdict(list)
|
||||||
|
# shortenable_ids = input_dict['shortenable_ids']
|
||||||
|
# for key in input_dict:
|
||||||
|
# parts = input_dict[key]
|
||||||
|
# to_trunc = num_tokens_to_truncate
|
||||||
|
# for i, part in enumerate(parts):
|
||||||
|
# if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
|
||||||
|
# continue
|
||||||
|
# parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||||
|
# to_trunc -= len(part)
|
||||||
|
# if to_trunc <= 0:
|
||||||
|
# break
|
||||||
|
# truncated_example[key] = parts
|
||||||
|
# return truncated_example
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def concate_parts(input_dict: Dict) -> Dict:
|
||||||
|
# for key in input_dict:
|
||||||
|
# input_dict[key] = list(itertools.chain(*input_dict[key]))
|
||||||
|
# return input_dict
|
||||||
|
|
||||||
|
|
||||||
|
# def truncate(self, encoder_inputs):
|
||||||
|
# total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
|
||||||
|
# num_specials = self.num_special_tokens_to_add
|
||||||
|
# # print("num_specials", num_specials)
|
||||||
|
# num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
|
||||||
|
# self.total_passed_sentences+=1
|
||||||
|
# if num_tokens_to_truncate>0:
|
||||||
|
# self.num_truncated_sentences += 1
|
||||||
|
# if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
|
||||||
|
# raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
|
||||||
|
# encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
|
||||||
|
# num_tokens_to_truncate=num_tokens_to_truncate)
|
||||||
|
# return encoder_inputs
|
||||||
|
|
||||||
|
# def tokenizer_preprocessor(self, example):
|
||||||
|
# # source, target = example
|
||||||
|
# # from IPython import embed; embed(header="Trehre2")
|
||||||
|
# label = example['label']
|
||||||
|
# guid = example['idx']
|
||||||
|
# meta = dict(example)
|
||||||
|
# meta.pop("label")
|
||||||
|
# meta.pop("idx")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
encoder_inputs = defaultdict(list)
|
# # from IPython import embed; embed(header="Trehre2")
|
||||||
for piece in wrapped_example:
|
|
||||||
encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
|
|
||||||
encoder_inputs['input_ids'].append(encode_text)
|
|
||||||
encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
|
|
||||||
|
|
||||||
|
# e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
|
||||||
|
|
||||||
encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
|
# if self.predict_with_generate:
|
||||||
encoder_inputs.pop("shortenable_ids")
|
# e = self.verbalizer.wrap_one_example(e)
|
||||||
encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
# example_wrapped = self.template.wrap_one_example(e)
|
||||||
decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
|
# encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(example_wrapped)
|
||||||
|
# print(encoded_sentence)
|
||||||
# again_encode = self.tokenizer.encode(decoded_inputs, add_special_tokens=False, return_special_tokens_mask=True)
|
# if self.predict_with_generate:
|
||||||
# if len(again_encode)> self.max_seq_length - 2:
|
# # return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
|
||||||
# print("length exceed!")
|
# return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
|
||||||
# print(wrapped_example)
|
# else:
|
||||||
# print(encoder_inputs['input_ids'])
|
# return {"source": encoded_sentence, "label": label, 'target': e.target_text, 'extra_fields':{'dataset_name':self.name}}
|
||||||
# print(again_encode)
|
|
||||||
# print(decoded_inputs)
|
|
||||||
# exit()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# delete shortenable ids
|
|
||||||
|
|
||||||
# encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
|
||||||
# encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
|
|
||||||
# # create special input ids
|
|
||||||
# encoder_inputs['attention_mask'] = [1] *len(encoder_inputs['input_ids'])
|
|
||||||
# # padding
|
|
||||||
# encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)
|
|
||||||
|
|
||||||
return decoded_inputs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def balanced_truncate(input_dict: Dict,
|
|
||||||
num_tokens_to_truncate: int=0) -> Dict:
|
|
||||||
'''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
|
|
||||||
'''
|
|
||||||
shortenable_lens = [len(parts) if parts[0]==1 else 0
|
|
||||||
for parts in input_dict['shortenable_ids']]
|
|
||||||
total_shortenable_len = sum(shortenable_lens)
|
|
||||||
num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
|
|
||||||
for part_len in shortenable_lens]
|
|
||||||
round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
|
|
||||||
|
|
||||||
truncated_example = defaultdict(list)
|
|
||||||
for key in input_dict:
|
|
||||||
parts = input_dict[key]
|
|
||||||
for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
|
|
||||||
truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
|
|
||||||
return truncated_example
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def truncate_from_tail(input_dict: Dict,
|
|
||||||
num_tokens_to_truncate: int=0) -> Dict:
|
|
||||||
r"""truncate the inputs from the rear
|
|
||||||
"""
|
|
||||||
truncated_example = defaultdict(list)
|
|
||||||
shortenable_ids = input_dict['shortenable_ids']
|
|
||||||
|
|
||||||
for key in input_dict:
|
|
||||||
parts = input_dict[key]
|
|
||||||
to_trunc = num_tokens_to_truncate
|
|
||||||
for i, part in enumerate(parts[::-1]):
|
|
||||||
if len(part) == 0: # to prevent some part are empty after tokenization
|
|
||||||
continue
|
|
||||||
if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
|
|
||||||
continue
|
|
||||||
parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
|
|
||||||
to_trunc -= len(part)
|
|
||||||
if to_trunc <= 0:
|
|
||||||
break
|
|
||||||
truncated_example[key] = parts
|
|
||||||
return truncated_example
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def truncate_from_head(input_dict: Dict,
|
|
||||||
num_tokens_to_truncate: int=0) -> Dict:
|
|
||||||
r"""truncate the inputs from the head
|
|
||||||
"""
|
|
||||||
truncated_example = defaultdict(list)
|
|
||||||
shortenable_ids = input_dict['shortenable_ids']
|
|
||||||
for key in input_dict:
|
|
||||||
parts = input_dict[key]
|
|
||||||
to_trunc = num_tokens_to_truncate
|
|
||||||
for i, part in enumerate(parts):
|
|
||||||
if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
|
|
||||||
continue
|
|
||||||
parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
|
|
||||||
to_trunc -= len(part)
|
|
||||||
if to_trunc <= 0:
|
|
||||||
break
|
|
||||||
truncated_example[key] = parts
|
|
||||||
return truncated_example
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def concate_parts(input_dict: Dict) -> Dict:
|
|
||||||
for key in input_dict:
|
|
||||||
input_dict[key] = list(itertools.chain(*input_dict[key]))
|
|
||||||
return input_dict
|
|
||||||
|
|
||||||
# @staticmethod
|
|
||||||
# def padding(input_dict: Dict,
|
|
||||||
# max_len: int, pad_id_for_inputs: int=0, pad_id_for_others: int=0) -> None:
|
|
||||||
# for key, value in input_dict.items():
|
|
||||||
# if (len(input_dict[key]) > max_len):
|
|
||||||
# raise ValueError(f'''
|
|
||||||
# Truncated seq length of '{key}' still greater than max length '{max_len}.'
|
|
||||||
# One possible reason is that no enough shortenable parts in template. Try add {{"shortenable": "True"}} property.
|
|
||||||
# ''')
|
|
||||||
# if 'input' in key:
|
|
||||||
# input_dict[key].extend([pad_id_for_inputs]*(max_len-len(value)))
|
|
||||||
# else:
|
|
||||||
# input_dict[key].extend([pad_id_for_others]*(max_len-len(value)))
|
|
||||||
# return input_dict
|
|
||||||
|
|
||||||
|
|
||||||
# def add_special_tokens(self, encoder_inputs):
|
|
||||||
# # add special tokens
|
|
||||||
# for key in encoder_inputs:
|
|
||||||
# if key == "input_ids":
|
|
||||||
# with warnings.catch_warnings():
|
|
||||||
# warnings.simplefilter("ignore")
|
|
||||||
# encoder_inputs[key] = self.tokenizer.build_inputs_with_special_tokens(
|
|
||||||
# encoder_inputs[key])
|
|
||||||
# return encoder_inputs
|
|
||||||
|
|
||||||
def truncate(self, encoder_inputs):
|
|
||||||
total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
|
|
||||||
num_specials = self.num_special_tokens_to_add
|
|
||||||
# print("num_specials", num_specials)
|
|
||||||
num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
|
|
||||||
self.total_passed_sentences+=1
|
|
||||||
if num_tokens_to_truncate>0:
|
|
||||||
self.num_truncated_sentences += 1
|
|
||||||
if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
|
|
||||||
raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
|
|
||||||
encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
|
|
||||||
num_tokens_to_truncate=num_tokens_to_truncate)
|
|
||||||
return encoder_inputs
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -234,46 +221,23 @@ class AbstractTask(abc.ABC):
 "superglue-boolq", "qqp", "qnli", "superglue-record", "sst2"]
 large_data_without_all_splits = [] #["qqp", "qnli", "superglue-record", "sst2"]

-def __init__(self, config, data_args, tokenizer, predict_with_generate, seed=42, default_max_length=1):
+def __init__(self, config, data_args, seed=42, default_max_length=1):
 self.config = config
 self.seed = seed
 self.data_args = data_args
-self.tokenizer = tokenizer
+# self.tokenizer = tokenizer
-self.predict_with_generate = predict_with_generate
+# self.predict_with_generate = predict_with_generate
 self.default_max_length = default_max_length
-self.truncate_method = getattr(data_args, "truncate_method", "balanced")

-tid = getattr(config, "template_id", 0)
+# generation_paradigm = getattr(config, "generation_paradigm", True)
-vid = getattr(config, "verbalizer_id", 0)

-self.template = ManualTemplate(tokenizer=self.tokenizer, text = self.templates_text[tid])
-self.verbalizer = ManualVerbalizer(tokenizer=self.tokenizer, classes = self.labels_list, label_words=self.verbalizers[vid])

-# if self.predict_with_generate:
-# self.reverse_verbalizer = {(int(x) for x in self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))): label for label in self.labels_list}
-# else:
-# self.reverse_verbalizer = {int(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))[0]): label for label in self.labels_list}

-self.tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=self.data_args.max_source_length, tokenizer=self.tokenizer, truncate_method=self.truncate_method)

-generation_paradigm = getattr(config, "generation_paradigm", True)
 # self.prompt = PromptCollections[self.name](tid, vid, generation_paradigm)
-self.max_target_length = self.get_max_target_length(self.default_max_length)

-def get_max_target_length(self, default_max_length):
-if self.predict_with_generate:
-return max([len(label) for key, label in self.verbalizer.label_words_ids.items()])
-else:
-return default_max_length

-def seq2seq_format(self, source, target, extra_fields={}
+# def get_max_target_length(self, default_max_length):
-):
+# if self.predict_with_generate:
+# return -1
-return {'source': ' '.join(source),
+# else:
-'target': ' '.join(target),
+# return default_max_length
-'task': self.name,
-'extra_fields': extra_fields
-}

 def check_n_obs(self, n_obs, total_size):
 if n_obs is not None and n_obs > total_size:
@@ -312,37 +276,9 @@ class AbstractTask(abc.ABC):
 else:
 return indices[validation_size:]


-def map_dataset(self, dataset):
-# from IPython import embed; embed(header="in get target length")
-return dataset.map(self.preprocessor).map(self.tokenizer_preprocessor)

 def preprocessor(self, example):
 return example

-def tokenizer_preprocessor(self, example):
-# source, target = example
-# from IPython import embed; embed(header="Trehre2")
-label = example['label']
-guid = example['idx']
-meta = dict(example)
-meta.pop("label")
-meta.pop("idx")

-# from IPython import embed; embed(header="Trehre2")

-e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
-template_e = self.template.wrap_one_example(e)
-encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(template_e)
-if self.predict_with_generate:
-# return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
-raise NotImplementedError
-else:
-return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}


 def get(self, split, n_obs=None, split_validation_test=False):
 # For small datasets (n_samples < 10K) without test set, we divide validation set to
 # half, use one half as test set and one half as validation set.
@@ -368,7 +304,7 @@ class AbstractTask(abc.ABC):
 # shuffles the data and samples it.
 if n_obs is not None:
 dataset = self.subsample(dataset, n_obs)
-return self.map_dataset(dataset)
+return dataset.map(self.preprocessor)

 class Squad(AbstractTask):
 name = "squad"
@@ -387,25 +323,7 @@ class Squad(AbstractTask):
 return self.seq2seq_format(source, target, add_prefix)


-class MRPC(AbstractTask):
+##GLUE
-name = "mrpc"
-labels_list = ["0", "1"]
-metric = [metrics.f1_score, metrics.accuracy]
-metric_names = ["f1", "accuracy"]
-split_to_data_split = {"train": "train",
-"validation": "validation",
-"test": "validation"}

-def load_dataset(self, split):
-return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")

-# def preprocessor(self, example, add_prefix=True):
-# src_texts = ["sentence1:", example['sentence1'],
-# "sentence2:", example["sentence2"]]
-# tgt_texts = [str(example['label'])]
-# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


 class COLA(AbstractTask):
 name = "cola"
 labels_list = ["0", "1"]
@@ -415,15 +333,20 @@ class COLA(AbstractTask):
 "validation": "validation",
 "test": "validation"}

+templates_text = {"0": """sentence: {"meta": 'sentence', "shortenable":True} Are there any error in the sentence? {"mask"}""",
+}

+verbalizers = {
+"0":{ "0": "yes", "1": "no"}
+}

 def load_dataset(self, split):
+if self.data_args.datasets_load_from_disk:
+return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.cola")[split]
+else:
 return datasets.load_dataset('glue', 'cola',
 split=split, script_version="master")

-# def preprocessor(self, example, add_prefix=True):
-# src_texts = ["sentence:", example['sentence']]
-# tgt_texts = [str(example['label'])]
-# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)


 class SST2(AbstractTask):
 name = "sst2"
@@ -434,38 +357,50 @@ class SST2(AbstractTask):
|
||||||
"validation": "validation",
|
"validation": "validation",
|
||||||
"test": "validation"}
|
"test": "validation"}
|
||||||
|
|
||||||
verbalizers = [
|
verbalizers = {
|
||||||
|
"0":{"0":"negative","1":"positive"}
|
||||||
|
}
|
||||||
|
|
||||||
]
|
templates_text = {
|
||||||
|
"0":"""The sentiment of sentence: "{"meta":"sentence", "shortenable":True} is {"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.sst2")[split]
|
||||||
|
else:
|
||||||
return datasets.load_dataset('glue', 'sst2',
|
return datasets.load_dataset('glue', 'sst2',
|
||||||
split=split, script_version="master")
|
split=split, script_version="master")
|
||||||
|
|
||||||
def preprocessor(self, example, add_prefix=True):
|
|
||||||
src_texts = ["sentence:", example['sentence']]
|
|
||||||
tgt_texts = [str(example['label'])]
|
|
||||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
|
||||||
|
|
||||||
|
|
||||||
class STSB(AbstractTask):
|
class MRPC(AbstractTask):
|
||||||
name = "stsb"
|
name = "mrpc"
|
||||||
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)]
|
labels_list = ["0", "1"]
|
||||||
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
|
metric = [metrics.f1_score, metrics.accuracy]
|
||||||
metric_names = ["pearson", "spearmanr"]
|
metric_names = ["f1", "accuracy"]
|
||||||
split_to_data_split = {"train": "train",
|
split_to_data_split = {"train": "train",
|
||||||
"validation": "validation",
|
"validation": "validation",
|
||||||
"test": "validation"}
|
"test": "validation"}
|
||||||
|
|
||||||
def load_dataset(self, split):
|
|
||||||
return datasets.load_dataset('glue', 'stsb',
|
|
||||||
split=split, script_version="master")
|
|
||||||
|
|
||||||
def preprocessor(self, example, add_prefix=True):
|
templates_text = {
|
||||||
src_texts = ["sentence1:", example['sentence1'],
|
"0": """sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2: {"meta":"sentence2", "shortenable":True}. Are sentence1 and sentence2 equivalent? {"mask"}.""",
|
||||||
"sentence2:", example["sentence2"]]
|
}
|
||||||
tgt_texts = [str(round_stsb_target(example['label']))]
|
|
||||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
verbalizers = {
|
||||||
|
"0":{"0": "no","1": "yes"}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mrpc")[split]
|
||||||
|
else:
|
||||||
|
return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QQP(AbstractTask):
|
class QQP(AbstractTask):
|
||||||
|
@@ -477,14 +412,46 @@ class QQP(AbstractTask):
|
||||||
"validation": "validation",
|
"validation": "validation",
|
||||||
"test": "validation"}
|
"test": "validation"}
|
||||||
|
|
||||||
|
templates_text = {"0":
|
||||||
|
"""question1: {"meta": 'question1', "shortenable":True}. question2: {"meta": 'question2', "shortenable":True} Are question1 and question2 equivalent? {"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
|
verbalizers = {
|
||||||
|
"0":{"0": "no","1": "yes"}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qqp")[split]
|
||||||
|
else:
|
||||||
return datasets.load_dataset('glue', 'qqp',
|
return datasets.load_dataset('glue', 'qqp',
|
||||||
split=split, script_version="master")
|
split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class STSB(AbstractTask):
|
||||||
|
name = "stsb"
|
||||||
|
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)]
|
||||||
|
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
|
||||||
|
metric_names = ["pearson", "spearmanr"]
|
||||||
|
split_to_data_split = {"train": "train",
|
||||||
|
"validation": "validation",
|
||||||
|
"test": "validation"}
|
||||||
|
|
||||||
|
|
||||||
|
verbalizers = {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_dataset(self, split):
|
||||||
|
return datasets.load_dataset('glue', 'stsb',
|
||||||
|
split=split, script_version="master")
|
||||||
|
|
||||||
def preprocessor(self, example, add_prefix=True):
|
def preprocessor(self, example, add_prefix=True):
|
||||||
src_texts = ["question1:", example['question1'],
|
src_texts = ["sentence1:", example['sentence1'],
|
||||||
"question2:", example["question2"]]
|
"sentence2:", example["sentence2"]]
|
||||||
tgt_texts = [str(example['label'])]
|
tgt_texts = [str(round_stsb_target(example['label']))]
|
||||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -498,14 +465,29 @@ class MNLI(AbstractTask):
|
||||||
metric_names = ["accuracy"]
|
metric_names = ["accuracy"]
|
||||||
|
|
||||||
|
|
||||||
|
templates_text = {
|
||||||
|
"0":"""premise: {"meta": 'premise', "shortenable":True}. hypothesis: {"meta": 'hypothesis', "shortenable":True} Does the premise entails the hypothesis? {"mask"}.""",
|
||||||
|
}
|
||||||
|
|
||||||
|
verbalizers = {
|
||||||
|
"0":{
|
||||||
|
"0": "yes",
|
||||||
|
"1": "neutral",
|
||||||
|
"2": "no",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mnli")[split]
|
||||||
|
else:
|
||||||
return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")
|
return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")
|
||||||
|
|
||||||
def preprocessor(self, example, add_prefix=True):
|
# def preprocessor(self, example, add_prefix=True):
|
||||||
src_texts = ["premise:", example['premise'],
|
# src_texts = ["premise:", example['premise'],
|
||||||
"hypothesis", example["hypothesis"]]
|
# "hypothesis", example["hypothesis"]]
|
||||||
tgt_texts = [str(example['label'])]
|
# tgt_texts = [str(example['label'])]
|
||||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||||
|
|
||||||
|
|
||||||
class QNLI(AbstractTask):
|
class QNLI(AbstractTask):
|
||||||
|
@@ -517,14 +499,33 @@ class QNLI(AbstractTask):
|
||||||
"validation": "validation",
|
"validation": "validation",
|
||||||
"test": "validation"}
|
"test": "validation"}
|
||||||
|
|
||||||
|
templates_text = {
|
||||||
|
"0": """premise: {"meta": 'sentence', "shortenable":True}. hypothesis: {"meta": 'question', "shortenable":True}"""+
|
||||||
|
"""Does the premise entails the hypothesis? {"mask"}.""",
|
||||||
|
}
|
||||||
|
|
||||||
|
verbalizers = {
|
||||||
|
"0":{
|
||||||
|
"0": "yes",
|
||||||
|
"1": "no",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qnli")[split]
|
||||||
|
else:
|
||||||
return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
|
return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
|
||||||
|
|
||||||
def preprocessor(self, example, add_prefix=True):
|
# def load_dataset(self, split):
|
||||||
src_texts = ["question:", example['question'],
|
# return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
|
||||||
"sentence:", example["sentence"]]
|
|
||||||
tgt_texts = [str(example['label'])]
|
# def preprocessor(self, example, add_prefix=True):
|
||||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
# src_texts = ["question:", example['question'],
|
||||||
|
# "sentence:", example["sentence"]]
|
||||||
|
# tgt_texts = [str(example['label'])]
|
||||||
|
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||||
|
|
||||||
#Tested
|
#Tested
|
||||||
class RTE(AbstractTask):
|
class RTE(AbstractTask):
|
||||||
|
@@ -537,15 +538,15 @@ class RTE(AbstractTask):
 "test": "validation"}


-templates_text = [
+templates_text = {
-"""sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2:,"""+
+"0": """sentence1: {"meta": 'sentence1', "shortenable":True} sentence2: {"meta":"sentence2", "shortenable":True} The answer was {"mask"}.""",
-"""{"meta":"sentence2", "shortenable":True}. The answer was {"mask"}.""",
+}
-]

-verbalizers = [{
+verbalizers = {
-"0": "yes",
+"0":{"0": "yes",
 "1": "no"
-}]
+}
+}

 def load_dataset(self, split):
 if self.data_args.datasets_load_from_disk:
@@ -555,38 +556,6 @@ class RTE(AbstractTask):
|
||||||
split=split, script_version="master")
|
split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
#Tested
|
|
||||||
class SuperGLUEBoolQ(AbstractTask):
|
|
||||||
name="superglue-boolq"
|
|
||||||
labels_list = ['0', '1']
|
|
||||||
metric = [metrics.accuracy]
|
|
||||||
metric_names = ["accuracy"]
|
|
||||||
split_to_data_split = {"train": "train",
|
|
||||||
"validation": "validation",
|
|
||||||
"test": "validation"}
|
|
||||||
|
|
||||||
verbalizers = [
|
|
||||||
{
|
|
||||||
"0": "no",
|
|
||||||
"1": "yes"
|
|
||||||
},
|
|
||||||
]
|
|
||||||
mlmhead_verbalizers = {
|
|
||||||
"0": "no",
|
|
||||||
"1": "yes"
|
|
||||||
}
|
|
||||||
templates_text = [
|
|
||||||
"""hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
|
|
||||||
]
|
|
||||||
|
|
||||||
def load_dataset(self, split):
|
|
||||||
if self.data_args.datasets_load_from_disk:
|
|
||||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
|
|
||||||
else:
|
|
||||||
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WNLI(AbstractTask):
|
class WNLI(AbstractTask):
|
||||||
name = "wnli"
|
name = "wnli"
|
||||||
|
@@ -597,13 +566,13 @@ class WNLI(AbstractTask):
 "validation": "validation",
 "test": "validation"}

-verbalizers = [{
+verbalizers = {
-"0": "True",
+"0":{"0": "True",
 "1": "False",
-}]
+}
-templates_text = [
+}
-"""{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}."""
+templates_text = {"0": """{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}."""
-]
+}


 def load_dataset(self, split):
@@ -613,6 +582,34 @@ class WNLI(AbstractTask):
|
||||||
return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")
|
return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
|
#SuperGLUE
|
||||||
|
class SuperGLUEBoolQ(AbstractTask):
|
||||||
|
name="superglue-boolq"
|
||||||
|
labels_list = ['0', '1']
|
||||||
|
metric = [metrics.accuracy]
|
||||||
|
metric_names = ["accuracy"]
|
||||||
|
split_to_data_split = {"train": "train",
|
||||||
|
"validation": "validation",
|
||||||
|
"test": "validation"}
|
||||||
|
|
||||||
|
verbalizers = {
|
||||||
|
"0": {
|
||||||
|
"0": "no",
|
||||||
|
"1": "yes"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
templates_text = {
|
||||||
|
"0": """hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
|
def load_dataset(self, split):
|
||||||
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
|
||||||
|
else:
|
||||||
|
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
class SuperGLUECB(AbstractTask):
|
class SuperGLUECB(AbstractTask):
|
||||||
name = "superglue-cb"
|
name = "superglue-cb"
|
||||||
|
@@ -623,14 +620,15 @@ class SuperGLUECB(AbstractTask):
|
||||||
metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
|
metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
|
||||||
metric_names = ["f1_multiclass", "accuracy"]
|
metric_names = ["f1_multiclass", "accuracy"]
|
||||||
|
|
||||||
verbalizers = [{
|
verbalizers = {
|
||||||
"0": "yes",
|
"0":{"0": "yes",
|
||||||
"1": "no",
|
"1": "no",
|
||||||
"2": "maybe"
|
"2": "maybe"
|
||||||
}]
|
}
|
||||||
templates_text = [
|
}
|
||||||
"""hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}."""
|
templates_text = {
|
||||||
]
|
"0": """hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
if self.data_args.datasets_load_from_disk:
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
@@ -648,13 +646,15 @@ class SuperGLUECOPA(AbstractTask):
|
||||||
metric = [metrics.accuracy]
|
metric = [metrics.accuracy]
|
||||||
metric_names = ["accuracy"]
|
metric_names = ["accuracy"]
|
||||||
|
|
||||||
verbalizers = [{
|
verbalizers = {
|
||||||
|
"0":{
|
||||||
"0": "1",
|
"0": "1",
|
||||||
"1": "2",
|
"1": "2",
|
||||||
}]
|
}
|
||||||
templates_text = [
|
}
|
||||||
"""choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}."""
|
templates_text = {
|
||||||
]
|
"0": """choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
if self.data_args.datasets_load_from_disk:
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
@@ -673,19 +673,16 @@ class SuperGLUEMultiRC(AbstractTask):
|
||||||
metrics.accuracy]
|
metrics.accuracy]
|
||||||
metric_names = ["f1", "em"]
|
metric_names = ["f1", "em"]
|
||||||
|
|
||||||
# generation_verbalizers = [{
|
|
||||||
# "0": "no",
|
|
||||||
# "1": "yes",
|
|
||||||
# },
|
|
||||||
# ]
|
|
||||||
|
|
||||||
verbalizers = [{
|
verbalizers = {
|
||||||
|
"0": {
|
||||||
"0": "no",
|
"0": "no",
|
||||||
"1": "yes",
|
"1": "yes",
|
||||||
}]
|
}
|
||||||
templates_text = [
|
}
|
||||||
"""question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}."""
|
templates_text = {
|
||||||
]
|
"0": """question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}."""
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
|
@@ -720,15 +717,16 @@ class SuperGLUEWIC(AbstractTask):
|
||||||
metric = [metrics.accuracy]
|
metric = [metrics.accuracy]
|
||||||
metric_names = ["accuracy"]
|
metric_names = ["accuracy"]
|
||||||
|
|
||||||
verbalizers = [{
|
verbalizers = {
|
||||||
|
"0": {
|
||||||
"0": "No",
|
"0": "No",
|
||||||
"1": "Yes",
|
"1": "Yes",
|
||||||
}]
|
}
|
||||||
|
}
|
||||||
|
|
||||||
templates_text = [
|
templates_text = {
|
||||||
"""sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}.
|
"0": """sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}."""
|
||||||
"""
|
}
|
||||||
]
|
|
||||||
|
|
||||||
def load_dataset(self, split):
|
def load_dataset(self, split):
|
||||||
if self.data_args.datasets_load_from_disk:
|
if self.data_args.datasets_load_from_disk:
|
||||||
|
@@ -786,7 +784,7 @@ class SuperGLUEWIC(AbstractTask):
 # text = self._mark_span(text, example['span2_text'], span2_index, '#')
 # src_texts = ["text:", text]
 # tgt_texts = [str(example["label"])]
-# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
+# return self.fseq2seq_format(src_texts, tgt_texts, add_prefix)


 class SuperGLUERecord(AbstractTask):
@@ -875,9 +873,9 @@ TASK_MAPPING = OrderedDict(

 class AutoTask:
 @classmethod
-def get(self, task, config, data_args, tokenizer,predict_with_generate, seed=42):
+def get(self, task, config, data_args, seed=42):
 if task in TASK_MAPPING:
-return TASK_MAPPING[task](config, data_args, tokenizer,predict_with_generate, seed)
+return TASK_MAPPING[task](config, data_args, seed)
 raise ValueError(
 "Unrecognized task {} for AutoTask Model: {}.\n"
 "Task name should be one of {}.".format(
@@ -34,6 +34,7 @@ from transformers import (
 AutoModelForMaskedLM,
 AutoModelForSeq2SeqLM,
 AutoTokenizer,
+DataCollatorForSeq2Seq,
 HfArgumentParser,
 MBartTokenizer,
 default_data_collator,
@@ -41,8 +42,8 @@ from transformers import (
 )
 from transformers.trainer_utils import is_main_process, get_last_checkpoint
 # from ..seq2seq.utils import get_adapter_config
-from examples_prompt.data_processors import AutoTask, TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
+from examples_prompt.data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
-from examples_prompt.seq2seq_trainer import Seq2SeqTrainer
+from transformers import Seq2SeqTrainer
 # from training_args import AdapterTrainingArguments
 from examples_prompt.trainers.trainer_utils import save_training_config
 from dataclasses import dataclass, field
@@ -56,7 +57,8 @@ import json
 import numpy as np
 logger = logging.getLogger(__name__)

+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"

 TASK_TO_METRICS = {"mrpc": ["accuracy", "f1"],
 "cola": ['matthews_correlation'],
@@ -109,7 +111,6 @@ class RemainArgHfArgumentParser(HfArgumentParser):


 def main():

 # See all possible arguments in src/transformers/training_args.py
 # or by passing the --help flag to this script.
 # We now keep distinct sets of args, for a cleaner separation of concerns.
@@ -202,6 +203,16 @@ def main():
 use_auth_token=True if model_args.use_auth_token else None,
 )

+if training_args.predict_with_generate:
+model = AutoModelForSeq2SeqLM.from_pretrained(
+model_args.model_name_or_path,
+from_tf=bool(".ckpt" in model_args.model_name_or_path),
+config=config,
+cache_dir=model_args.cache_dir,
+revision=model_args.model_revision,
+use_auth_token=True if model_args.use_auth_token else None,
+)
+else:
 model = AutoModelForMaskedLM.from_pretrained(
 model_args.model_name_or_path,
 from_tf=bool(".ckpt" in model_args.model_name_or_path),
@@ -212,11 +223,6 @@ def main():
 )
 model.resize_token_embeddings(len(tokenizer))

-from openprompt.prompts import ManualTemplate
-from openprompt.plms import MLMTokenizerWrapper
-template = ManualTemplate(tokenizer, text="""sentence1: {"meta": 'premise'}, sentence2:,"""+
-"""{"meta":"hypothesis", "shortenable":True}, The answer was {"mask"} .""")
-tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method='tail')


@@ -248,49 +254,185 @@ def main():
|
||||||
|
|
||||||
# Temporarily set max_target_length for training.
|
# Temporarily set max_target_length for training.
|
||||||
#max_target_length = data_args.max_target_length
|
#max_target_length = data_args.max_target_length
|
||||||
padding = "max_length" if data_args.pad_to_max_length else False
|
|
||||||
|
|
||||||
def preprocess_function(examples, **kwargs):
|
|
||||||
# max_target_length += 1
|
|
||||||
tokenizer = kwargs['tokenizer']
|
|
||||||
data_args = kwargs['data_args']
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("max_length", data_args.max_source_length)
|
|
||||||
model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
|
|
||||||
padding="max_length", truncation=True)
|
|
||||||
|
|
||||||
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
|
|
||||||
# model_inputs["mask_position"] = mask_position
|
|
||||||
model_inputs["extra_fields"] = examples['extra_fields']
|
|
||||||
# from IPython import embed; embed(header="Therer")
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
|
|
||||||
column_names = ['source', 'target', 'label', 'extra_fields']
|
column_names = ['source', 'target', 'label', 'extra_fields']
|
||||||
performance_metrics = {}
|
performance_metrics = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompts(task, tokenizer, predict_with_generate, template_id="0", verbalizer_id="0"):
|
||||||
|
# tid = getattr(config, "template_id", "0")
|
||||||
|
# vid = getattr(config, "verbalizer_id", "0")
|
||||||
|
from openpromptu.prompts import GenerationVerbalizer, ManualVerbalizer
|
||||||
|
from openpromptu.prompts import ManualTemplate
|
||||||
|
template = ManualTemplate(text = task.templates_text[template_id])
|
||||||
|
if predict_with_generate:
|
||||||
|
verbalizer = GenerationVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
||||||
|
else:
|
||||||
|
verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
||||||
|
# max_target_length = self.get_max_target_length(self.default_max_length)
|
||||||
|
|
||||||
|
from openpromptu import TokenizerWrapper
|
||||||
|
tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
|
||||||
|
return template, verbalizer, tokenizer_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
from openpromptu.data_utils import InputExample
|
||||||
|
|
||||||
|
max_target_length = 32
|
||||||
|
|
||||||
|
if os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
||||||
|
mask_token_func = lambda i: tokenizer.additional_special_tokens[i]
|
||||||
|
def preprocess_function(raw_example, **kwargs):
|
||||||
|
# max_target_length += 1
|
||||||
|
tokenizer = kwargs['tokenizer']
|
||||||
|
data_args = kwargs['data_args']
|
||||||
|
template = kwargs['template']
|
||||||
|
verbalizer = kwargs['verbalizer']
|
||||||
|
tokenizer_wrapper = kwargs['tokenizer_wrapper']
|
||||||
|
split = kwargs['split']
|
||||||
|
# extra_fileds = example['extra_fields']
|
||||||
|
|
||||||
|
example = InputExample(**raw_example)
|
||||||
|
|
||||||
|
# from collections import namedtuple
|
||||||
|
# example['tgt_text'] = ""
|
||||||
|
# example = namedtuple("ObjectName", example.keys())(*example.values())
|
||||||
|
try:
|
||||||
|
example = verbalizer.wrap_one_example(example)
|
||||||
|
example, other = template.wrap_one_example(example)
|
||||||
|
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
|
||||||
|
model_inputs = tokenizer(input_sentence, max_length=256,
|
||||||
|
padding="max_length", truncation=True)
|
||||||
|
except:
|
||||||
|
from IPython import embed; embed(header="Therer")
|
||||||
|
|
||||||
|
|
||||||
|
# if split == "train":
|
||||||
|
with tokenizer.as_target_tokenizer():
|
||||||
|
label = tokenizer(other['tgt_text']).input_ids
|
||||||
|
# label = [l if l != tokenizer.pad_token_id else -100 for l in label]
|
||||||
|
|
||||||
|
# from IPython import embed; embed(header="Therer")
|
||||||
|
model_inputs["labels"] = label
|
||||||
|
# else:
|
||||||
|
# # from IPython import embed; embed(header="Therer")
|
||||||
|
# model_inputs["tgt_text"] = other['tgt_text']
|
||||||
|
# model_inputs['labels'] = None # model_inputs["extra_fields"] = extra_fileds
|
||||||
|
# from IPython import embed; embed(header="Therer2")
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def compute_metrics(eval_preds, tokenizer, dataset_name, eval_metric):
|
||||||
|
# from IPython import embed; embed(header="In compute metrics")
|
||||||
|
preds, labels = eval_preds
|
||||||
|
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||||
|
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||||
|
# post_processor = .get(data_args.dataset_name[0], tokenizer,
|
||||||
|
# data_args.ignore_pad_token_for_loss)
|
||||||
|
# decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
|
||||||
|
result = {}
|
||||||
|
for metric in eval_metric:
|
||||||
|
result.update(metric(decoded_preds, decoded_labels))
|
||||||
|
|
||||||
|
average_metric = sum(result.values())/len(result)
|
||||||
|
result.update({"average_metrics":average_metric})
|
||||||
|
return result
|
||||||
|
+
+    elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
+            or os.path.basename(model_args.model_name_or_path).startswith("bert"):
+        mask_token_func = lambda i: tokenizer.mask_token
+
+        def preprocess_function(raw_example, **kwargs):
+            # max_target_length += 1
+            # from IPython import embed; embed(header="Therer")
+            tokenizer = kwargs['tokenizer']
+            data_args = kwargs['data_args']
+            template = kwargs['template']
+            verbalizer = kwargs['verbalizer']
+            tokenizer_wrapper = kwargs['tokenizer_wrapper']
+
+            example = InputExample(**raw_example)
+            # from collections import namedtuple
+            # example['tgt_text'] = ""
+            # example = namedtuple("ObjectName", example.keys())(*example.values())
+            # try:
+            # example = verbalizer.wrap_one_example(example)
+            example, other = template.wrap_one_example(example)
+            input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
+            model_inputs = tokenizer(input_sentence, max_length=256,
+                                     padding="max_length", truncation=True)
+
+            # print("max_length", data_args.max_source_length)
+            # model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
+            #                          padding="max_length", truncation=True)
+
+            # mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
+            # model_inputs["mask_position"] = mask_position
+            # model_inputs["extra_fields"] = examples['extra_fields']
+            # from IPython import embed; embed(header="Therer")
+            return model_inputs
+
+        def compute_metrics(eval_preds, dataset_name):
+            # from IPython import embed; embed(header="In compute metrics")
+            preds, labels = eval_preds.predictions, eval_preds.label_ids
+            preds = np.argmax(preds, axis=-1)
+            result = {}
+            average_metrics = []
+            for metric in eval_metric:
+                metric_item = metric(preds, labels)
+                metric_value = list(metric_item.values())
+                result.update(metric_item)
+                average_metrics.extend(metric_value)
+            print("average:",average_metrics)
+            average_metric = sum(average_metrics)/len(average_metrics)
+            result.update({"average_metrics":average_metric})
+            return result
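Since MLMTrainer (further down in this diff) replaces outputs.logits with the verbalizer's per-class label logits, eval_preds.predictions arriving in this classification-style compute_metrics is a [num_examples, num_classes] array, and the argmax picks one class id per example before the metrics run. A minimal stand-alone sketch of that step, with made-up numbers that are not part of the commit:

    import numpy as np

    # toy stand-in: label logits for 3 examples over 2 verbalizer classes
    preds = np.array([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])
    labels = np.array([1, 0, 0])
    pred_ids = np.argmax(preds, axis=-1)                 # -> array([1, 0, 1])
    accuracy = {"accuracy": float((pred_ids == labels).mean())}  # -> {'accuracy': ~0.67}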
     if training_args.do_train:
         train_task = AutoTask.get(data_args.task_name,
                                   data_args.dataset_config_name,
                                   data_args=data_args,
-                                  tokenizer=tokenizer,
-                                  predict_with_generate=training_args.predict_with_generate,
+                                  # tokenizer=tokenizer,
+                                  # predict_with_generate=training_args.predict_with_generate,
                                   seed=data_args.data_seed)

         train_dataset = train_task.get(split='train',
                                        split_validation_test=training_args.split_validation_test,
                                        n_obs=data_args.max_train_samples)

+        template, verbalizer, tokenizer_wrapper = get_prompts(train_task, tokenizer, training_args.predict_with_generate)
+
         train_dataset = train_dataset.map(
             functools.partial(preprocess_function,
                               data_args=data_args,
-                              tokenizer=tokenizer),
-            batched=True,
+                              tokenizer=tokenizer,
+                              template=template,
+                              verbalizer=verbalizer,
+                              tokenizer_wrapper=tokenizer_wrapper,
+                              split="train"),
+            batched=False,
             num_proc=data_args.preprocessing_num_workers,
             remove_columns=[x for x in train_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
             load_from_cache_file=not data_args.overwrite_cache,
@@ -298,6 +440,8 @@ def main():

     eval_splits_names = []

     if training_args.do_eval:
@@ -306,87 +450,49 @@ def main():
         eval_splits_names.append("test")
     eval_splits = {}
     for split_name in eval_splits_names:
-        _tasks = {dataset_name: AutoTask.get(dataset_name,
-                          dataset_config_name,
+        eval_task = AutoTask.get(data_args.task_name,
+                          data_args.dataset_config_name,
                           data_args=data_args,
-                          tokenizer=tokenizer,
-                          predict_with_generate=training_args.predict_with_generate,
+                          # tokenizer=tokenizer,
+                          # predict_with_generate=training_args.predict_with_generate,
                           seed=data_args.data_seed)
-                  for dataset_name, dataset_config_name\
-                  in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}
+        # for dataset_name, dataset_config_name\
+        # in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}

-        _datasets = {dataset_name: task.get(split=split_name,
+        eval_dataset = eval_task.get(split=split_name,
                           split_validation_test=training_args.split_validation_test,
                           n_obs=data_args.max_train_samples)
-                     for dataset_name, task in _tasks.items()
-                     }
-
-        _datasets = {dataset_name: d.map(
+
+        template, _verbalizer, tokenizer_wrapper = get_prompts(eval_task, tokenizer, training_args.predict_with_generate)
+
+        eval_dataset = eval_dataset.map(
             functools.partial(preprocess_function,
                               data_args=data_args,
-                              tokenizer=tokenizer),
-            batched=True,
+                              tokenizer=tokenizer,
+                              template=template,
+                              verbalizer=_verbalizer,
+                              tokenizer_wrapper=tokenizer_wrapper,
+                              split=split_name),
+            batched=False,
             num_proc=data_args.preprocessing_num_workers,
-            remove_columns=[x for x in d.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
+            remove_columns=[x for x in eval_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
             load_from_cache_file=not data_args.overwrite_cache,
             )
-            for dataset_name, d in _datasets.items()
-            }
-
-        eval_splits[split_name] = _datasets
+
+        eval_splits[split_name] = eval_dataset
         if split_name == "test":
-            eval_metrics = {dataset_name:task.metric for dataset_name, task in _tasks.items()}
-            verbalizers = {dataset_name:task.verbalizer for dataset_name, task in _tasks.items()}
+            eval_metric = eval_task.metric
+            verbalizer = _verbalizer
-    # Metric, we assume we have only one training task.
-    # eval_metrics = [task.metric for task in
-    # for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)][0]
-
-    # Extracts the extra information needed to evaluate on each dataset.
-    # These information are only used in the compute_metrics.
-    # We will assume that the test/eval dataloader does not change the order of
-    # the data.
-    # data_info = {"eval": eval_datasets[data_args.eval_dataset_name[0]]['extra_fields'],
-    #              "test": test_datasets[data_args.test_dataset_name[0]]['extra_fields'],
-    #              "train": train_dataset['extra_fields']}
-
-    def compute_metrics(eval_preds, dataset_name):
-        preds, labels = eval_preds.predictions, eval_preds.label_ids
-        preds = np.argmax(preds, axis=-1)
-        result = {}
-        average_metrics = []
-        for metric in eval_metrics[dataset_name]:
-            metric_item = metric(preds, labels)
-            metric_value = list(metric_item.values())
-            result.update(metric_item)
-            average_metrics.extend(metric_value)
-        print("average:",average_metrics)
-        average_metric = sum(average_metrics)/len(average_metrics)
-        result.update({"average_metrics":average_metric})
-        return result
-
-    # from IPython import embed; embed(header="isseq2seq")
-    # Initialize our Trainer
-    # if training_args.is_seq2seq == True:
-    #     trainer = Seq2SeqTrainer(
-    #         model=model,
-    #         args=training_args,
-    #         delta_args=delta_args,
-    #         train_dataset=splits['train'] if training_args.do_train else None,
-    #         eval_dataset=list(splits['validation'].values())[0] if training_args.do_eval else None,
-    #         data_info = data_info,
-    #         tokenizer=tokenizer,
-    #         data_collator=data_collator,
-    #         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-    #         evaluation_metrics = TASK_TO_METRICS[data_args.dataset_name[0]],
-    #     )
-    # else:
     class MLMTrainer(Trainer):
-        _verbalizers = verbalizers
+        def __init__(self, verbalizer=None, **kwargs):
+            super().__init__(**kwargs)
+            self.verbalizer=verbalizer

         # def training_step(self, model, inputs):
         #     from IPython import embed; embed(header="in trainstep")
@@ -394,7 +500,7 @@ def main():
         def compute_loss(self, model, inputs, return_outputs=False):
             labels = inputs.pop('labels')
-            extra_fields = inputs.pop("extra_fields")
+            # extra_fields = inputs.pop("extra_fields")
             outputs = model(**inputs)
             logits = outputs.get("logits")
             input_ids = inputs['input_ids']
@@ -402,20 +508,8 @@ def main():
             # from IPython import embed; embed(header="382")
-            verbalizer = self._verbalizers[extra_fields[0]['dataset_name']].cuda()
+            verbalizer = self.verbalizer.cuda()
             logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
-            # colidx = torch.where(input_ids == verbalizer.tokenizer.mask_token_id)[0].cpu()
-            # print(colidx)
-            # missing = set([i for i in range(input_ids.size(0))]) - set(colidx.numpy())
-            # print(missing)
-            # if len(missing) > 0:
-            #     print("missing")
-            #     missing = list(missing)[0]
-            #     input_ids_missing = input_ids[missing]
-            #     print(input_ids_missing)
-            #     missing_tokens = verbalizer.tokenizer.convert_ids_to_tokens(input_ids_missing)
-            #     print(missing_tokens)
             label_logits = verbalizer.process_logits(logits_at_mask)
             loss_fct = torch.nn.CrossEntropyLoss()
             # from IPython import embed; embed(header="In compute loss")
@@ -423,21 +517,136 @@ def main():
             outputs.logits = label_logits
             return (loss, outputs) if return_outputs else loss

+    class MySeq2SeqTrainer(Seq2SeqTrainer):
+        def compute_loss(self, model, inputs, return_outputs=False):
+            # from IPython import embed; embed(header="agag")
+            intlabel = inputs.pop('label')
+            # extra_fields = inputs.pop("extra_fields")
+            outputs = model(**inputs)
+            # logits = outputs.get("logits")
+            # input_ids = inputs['input_ids']
+
+            # # from IPython import embed; embed(header="382")
+            # verbalizer = self._verbalizers.cuda()
+            # logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
+            # label_logits = verbalizer.process_logits(logits_at_mask)
+            # loss_fct = torch.nn.CrossEntropyLoss()
+            # # from IPython import embed; embed(header="In compute loss")
+            # loss = loss_fct(label_logits, labels)
+            # outputs.logits = label_logits
+            if return_outputs:
+                return (outputs.loss, outputs)
+            else:
+                return outputs.loss
+
+        # def evaluate(
+        #     self,
+        #     eval_dataset: Optional[Dict[str, Dataset]] = None,
+        #     ignore_keys: Optional[List[str]] = None,
+        #     metric_key_prefix: str = "eval",
+        #     max_length: Optional[int] = None,
+        #     num_beams: Optional[int] = None,
+        # ) -> Dict[str, float]:
+        #     # TODO: this also needs to be set per dataset
+        #     self._max_length = max_length
+        #     self._num_beams = num_beams
+        #     return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+        def prediction_step(
+            self,
+            model,  # nn.Module,
+            inputs,  # Dict[str, Union[torch.Tensor, Any]],
+            prediction_loss_only,  # : bool,
+            ignore_keys,  # : Optional[List[str]] = None,
+        ):  # -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+            """
+            Perform an evaluation step on :obj:`model` using :obj:`inputs`.
+
+            Subclass and override to inject custom behavior.
+
+            Args:
+                model (:obj:`nn.Module`):
+                    The model to evaluate.
+                inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                    The inputs and targets of the model.
+
+                    The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                    argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+                prediction_loss_only (:obj:`bool`):
+                    Whether or not to return the loss only.
+
+            Return:
+                Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+                labels (each being optional).
+            """
+            if not self.args.predict_with_generate or prediction_loss_only:
+                return super().prediction_step(
+                    model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+                )
+
+            has_labels = "labels" in inputs
+            inputs = self._prepare_inputs(inputs)
+            intlabel = inputs.pop('label')
+            gen_kwargs = {
+                "max_length": 10, # self._max_length if s is not None else self.model.config.max_length,
+                "num_beams": 1 #self._num_beams if self._num_beams is not None else self.model.config.num_beams,
+            }
+            generated_tokens = self.model.generate(
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
+                **gen_kwargs,
+            )
+            # in case the batch is shorter than max length, the output should be padded
+            if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
+                generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
+
+            with torch.no_grad():
+                outputs = model(**inputs)
+                if has_labels:
+                    if self.label_smoother is not None:
+                        loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
+                    else:
+                        loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+                else:
+                    loss = None
+
+            if self.args.prediction_loss_only:
+                return (loss, None, None)
+
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_kwargs["max_length"]:
+                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
+
+            # from IPython import embed; embed(header="In seqseqtrainer")
+            return (loss, generated_tokens, labels)
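Because generate() can stop early, the step above pads every generated batch out to the fixed gen_kwargs["max_length"] so predictions and labels can be stacked across batches before the generation-based compute_metrics decodes them. A rough stand-alone sketch of that padding, assuming a pad id of 0 (the real helper, Seq2SeqTrainer's _pad_tensors_to_max_len, uses the tokenizer/config pad id instead):

    import torch
    import torch.nn.functional as F

    max_length = 10
    generated_tokens = torch.tensor([[71, 2163, 1], [71, 465, 1]])  # toy ids, generation stopped after 3 tokens
    if generated_tokens.shape[-1] < max_length:
        # pad on the right so every row has exactly max_length ids
        generated_tokens = F.pad(generated_tokens,
                                 (0, max_length - generated_tokens.shape[-1]),
                                 value=0)
    print(generated_tokens.shape)  # torch.Size([2, 10])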
         # def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
         #     aa = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
         #     # from IPython import embed; embed()
         #     return aa

-    from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
-    class DataCollatorWithExtraFields(DataCollatorMixin):
-        return_tensors: str = "pt"
-        def torch_call(self, features):
-            print(len(features))
-            extra_fields = [f.pop('extra_fields') for f in features]
-            batch = torch_default_data_collator(features)
-            batch['extra_fields'] =extra_fields
-            print(batch['input_ids'].size())
-            print(batch['labels'].size())
-            return batch
+    # from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
+    # class DataCollatorWithExtraFields(DataCollatorMixin):
+    #     return_tensors: str = "pt"
+    #     def torch_call(self, features):
+    #         # print(len(features))
+    #         # extra_fields = [f.pop('extra_fields') for f in features]
+    #         batch = torch_default_data_collator(features)
+    #         batch['extra_fields'] =extra_fields
+    #         # print(batch['input_ids'].size())
+    #         # print(batch['labels'].size())
+    #         return batch

     # from transformers.data.data_collator import DefaultDataCollator
@@ -455,15 +664,29 @@ def main():
     training_args.remove_unused_columns = False

+    if os.path.basename(model_args.model_name_or_path).startswith("roberta") or \
+            os.path.basename(model_args.model_name_or_path).startswith("bert"):
         trainer = MLMTrainer(
             model=model,
             args=training_args,
             train_dataset=train_dataset if training_args.do_train else None,
-            eval_dataset=eval_splits['eval'][data_args.task_name] if training_args.do_eval else None,
+            eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
             compute_metrics=functools.partial(compute_metrics, dataset_name=data_args.task_name),
-            # tokenizer=tokenizer,
-            data_collator=DataCollatorWithExtraFields(),
+            tokenizer=tokenizer,
+            # data_collator=DataCollatorWithExtraFields(),
+            verbalizer=verbalizer,
         )
+    elif os.path.basename(model_args.model_name_or_path).startswith("t5"):
+        trainer = MySeq2SeqTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset if training_args.do_train else None,
+            eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
+            compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer, dataset_name=data_args.task_name, eval_metric=eval_metric),
+            tokenizer=tokenizer,
+            data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
+        )

     # Saves training config.
@@ -522,24 +745,23 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        for dataset_name, eval_dataset in eval_splits['eval'].items():
-            metrics = trainer.evaluate(eval_dataset=eval_dataset,
-                                       )
-            trainer.log_metrics(f"{dataset_name}_eval", metrics)
-            trainer.save_metrics(f"{dataset_name}_eval", metrics)
-            all_results['evaluate'][dataset_name] = metrics
+        metrics = trainer.evaluate(eval_dataset=eval_splits['eval'],
+                                   )
+        trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
+        trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
+        all_results['evaluate'][data_args.task_name] = metrics

     # Test
     all_results['test'] = {}
     if training_args.do_test:
         logger.info("*** Test ***")
-        for dataset_name, test_dataset in eval_splits['test'].items():
-            metrics = trainer.evaluate(eval_dataset=test_dataset,
-                                       metric_key_prefix="test"
-                                       )
-            trainer.log_metrics(f"{dataset_name}_test", metrics)
-            trainer.save_metrics(f"{dataset_name}_test", metrics)
-            all_results['test'][dataset_name] = metrics
+        metrics = trainer.evaluate(eval_dataset=eval_splits['test'],
+                                   metric_key_prefix="test"
+                                   )
+        trainer.log_metrics(f"{data_args.task_name}_test", metrics)
+        trainer.save_metrics(f"{data_args.task_name}_test", metrics)
+        all_results['test'][data_args.task_name] = metrics

     # repo_name = create_hub_repo_name(root="DeltaHub",
     #                                  dataset=data_args.task_name,
@@ -1,7 +1,7 @@
-cd configs/
-python config_gen.py --job $3
+python configs/config_gen.py --job $3
 echo "Regenerate config"
-cd ../
 files=(cola mnli mrpc qnli qqp rte sst2 stsb superglue-boolq superglue-cb superglue-copa superglue-multirc superglue-record superglue-wic superglue-wsc.fixed)
 for ((i=$1; i<=$2; i++))
 do
@@ -1,5 +1,5 @@

-__version__ = "0.0.1"
+__version__ = "0.0.4"

 class GlobalSetting:
     def __init__(self):
@@ -728,6 +728,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
             elif _delta_info['method'] == "insert_sequential":
                 self.insert_sequential_module(module=submodule,
                                               _delta_info=_delta_info)
+            elif _delta_info['method'] == "insert_parallel":
+                self.insert_parallel_module(module=submodule,
+                                            _delta_info=_delta_info)
             else:
                 raise NotImplementedError

@@ -765,7 +768,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
                 submodule.forward = submodule.forward.__wrapped__
                 delattr(submodule, _delta_info["delta_name"])
             else:
-                raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It'ss not a wrapped function.".format(name))
+                raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
+        elif _delta_info['method'] == "insert_parallel":
+            if hasattr(submodule.forward, "__wrapped__"):
+                submodule.forward = submodule.forward.__wrapped__
+                delattr(submodule, _delta_info["delta_name"])
+            else:
+                raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
         else:
             raise NotImplementedError

@@ -776,4 +785,3 @@ class DeltaBase(nn.Module, SaveLoadMixin):
         except AttributeError:
             pass
@@ -28,7 +28,6 @@ class LowRankLinear(nn.Module):
         self.r = r
         self.lora_alpha = lora_alpha
         self.lora_dropout = lora_dropout
-        self.lin = nn.Linear(in_features, out_features) #
         if lora_dropout > 0.:
             self.lora_dropout = nn.Dropout(p=lora_dropout)
         else:

@@ -37,7 +36,6 @@ class LowRankLinear(nn.Module):
         self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
         self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
         self.scaling = self.lora_alpha / self.r
-        self.lin.reset_parameters() #
         nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
         nn.init.zeros_(self.lora_B)
@@ -1,3 +1,4 @@
+from examples_prompt.metrics.metrics import exact_match
 from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
 from opendelta.utils.name_based_addressing import *
 from opendelta.utils.cuda import get_device
@@ -47,6 +48,7 @@ class SoftPromptLayer(nn.Module):
                  soft_token_num: int = 100,
                  raw_embedding: Optional[torch.Tensor] = None,
                  init_range: Optional[float] = 0.5,
+                 other_expand_ids: Optional[Dict] = {"attention_mask":1, "token_type_ids":0},
                  token_init = False,
                  pad_id = 0,
                  device: Optional[str]=None,
@@ -59,10 +61,13 @@ class SoftPromptLayer(nn.Module):
         self.pad_id = pad_id
         self.token_init = token_init
         self.device = device
+        self.other_expand_ids = other_expand_ids

         assert self.num_tokens>0
         self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])

+        self.all_pseudo_tokens = {}
+
     def pre_forward(self, *args, **kwargs):
         # if attention_mask is passed as PLM's input, modify it here
         if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None:
@@ -100,8 +105,19 @@ class SoftPromptLayer(nn.Module):
             inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
             kwargs['inputs_embeds'] = inputs_embeds

-            am = kwargs['attention_mask']
-            am.data = torch.cat([torch.ones((*am.shape[:-1], inputs_embeds.shape[-2]-am.shape[-1]), dtype = am.dtype,device=am.device), am], dim=-1)
+            for expand_key in self.other_expand_ids:
+                if expand_key in kwargs:
+                    real_tokens = kwargs[expand_key]
+                    if expand_key in self.all_pseudo_tokens:
+                        pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
+                    else:
+                        pseudo_tokens_value = self.other_expand_ids[expand_key]
+                        pseudo_tokens = torch.ones(
+                            (*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
+                            dtype = real_tokens.dtype,
+                            device=real_tokens.device) * pseudo_tokens_value
+                        self.all_pseudo_tokens[expand_key] = pseudo_tokens
+                    real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)

         return args, kwargs
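The loop above generalizes the previous attention_mask-only handling: for every key named in other_expand_ids, a block of pseudo tokens filled with the configured value is prepended so the tensor keeps matching the soft-prompt-extended inputs_embeds. A stand-alone sketch of the same idea with made-up shapes (100 soft tokens, batch of 2, 16 real tokens), not taken from the commit:

    import torch

    num_soft_tokens = 100
    real_inputs = {
        "attention_mask": torch.ones(2, 16, dtype=torch.long),   # all real tokens attended to
        "token_type_ids": torch.zeros(2, 16, dtype=torch.long),  # single-segment input
    }
    other_expand_ids = {"attention_mask": 1, "token_type_ids": 0}

    expanded = {}
    for key, fill_value in other_expand_ids.items():
        real = real_inputs[key]
        # pseudo block for the prepended soft-prompt positions
        pseudo = torch.full((real.shape[0], num_soft_tokens), fill_value, dtype=real.dtype)
        expanded[key] = torch.cat([pseudo, real], dim=-1)

    print(expanded["attention_mask"].shape)  # torch.Size([2, 116])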
@@ -136,6 +152,10 @@ class SoftPromptModel(DeltaBase):
         soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input.
         init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution.
         token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the plm
+        other_expand_ids (:obj:`dict`, *optional*, default to `{"attention_mask":1, "token_type_ids":0}`): The names of
+            other tokens, and their default fill values, that expand along with the input sequence. For example, when
+            you prepend 100 tokens to the input_ids, the attention_mask should be extended, and the token_type_ids should
+            be extended as well.
         modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
            the implemented ones)
         unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
|
||||||
soft_token_num=100,
|
soft_token_num=100,
|
||||||
init_range = 0.5,
|
init_range = 0.5,
|
||||||
token_init=True,
|
token_init=True,
|
||||||
|
other_expand_ids={"attention_mask":1, "token_type_ids":0},
|
||||||
modified_modules: Optional[List[str]] = None,
|
modified_modules: Optional[List[str]] = None,
|
||||||
exclude_modules: Optional[List[str]] = None,
|
exclude_modules: Optional[List[str]] = None,
|
||||||
unfrozen_modules: Optional[List[str]] = None,
|
unfrozen_modules: Optional[List[str]] = None,
|
||||||
|
@@ -202,6 +223,7 @@ class SoftPromptModel(DeltaBase):
         soft_prompt_layer = SoftPromptLayer(
             soft_token_num = self.soft_token_num,
             raw_embedding = self.raw_embedding,
+            other_expand_ids = self.other_expand_ids,
             token_init = self.token_init,
             init_range = self.init_range,
             device = module_device,
setup.py
@@ -17,7 +17,7 @@ print(requires)
 with open('README.md', 'r') as f:
     setuptools.setup(
         name = 'opendelta',
-        version = '0.0.3',
+        version = '0.0.4',
         description = "An open source framework for delta learning (parameter efficient learning).",
         long_description=open("README.md", "r", encoding="utf-8").read(),
         long_description_content_type="text/markdown",