example_prompt_unify
This commit is contained in:
parent
151e3dd2c2
commit
0de3cbd31d
Binary file not shown.
Binary file not shown.
|
@ -1,3 +1,3 @@
|
|||
from .tasks import TASK_MAPPING, AutoTask
|
||||
from .data_collator import TaskDataCollatorForSeq2Seq
|
||||
from .postprocessors import AutoPostProcessor
|
||||
# from .data_collator import TaskDataCollatorForSeq2Seq
|
||||
# from .postprocessors import AutoPostProcessor
|
||||
|
|
|
@ -1,28 +1,28 @@
|
|||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
# import numpy as np
|
||||
# from dataclasses import dataclass
|
||||
# from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
def check_uniqueness(self, samples):
|
||||
assert len(np.unique(samples)) == 1
|
||||
|
||||
def __call__(self, features):
|
||||
# tasks = [d.pop('task') for d in features]
|
||||
# self.check_uniqueness(tasks)
|
||||
output = super().__call__(features)
|
||||
# output["task"] = tasks[0]
|
||||
return output
|
||||
|
||||
# class CustomDataCollator(DefaultDataCollator):
|
||||
# @dataclass
|
||||
# class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
# def check_uniqueness(self, samples):
|
||||
# assert len(np.unique(samples)) == 1
|
||||
|
||||
# def __call__(self, features):
|
||||
# mask_positions = [d.pop('mask_positions') for d in features]
|
||||
# # tasks = [d.pop('task') for d in features]
|
||||
# # self.check_uniqueness(tasks)
|
||||
# output = super().__call__(features)
|
||||
|
||||
# # output["task"] = tasks[0]
|
||||
# return output
|
||||
|
||||
# # class CustomDataCollator(DefaultDataCollator):
|
||||
# # def check_uniqueness(self, samples):
|
||||
# # assert len(np.unique(samples)) == 1
|
||||
|
||||
# # def __call__(self, features):
|
||||
# # mask_positions = [d.pop('mask_positions') for d in features]
|
||||
# # # self.check_uniqueness(tasks)
|
||||
# # output = super().__call__(features)
|
||||
|
||||
# # # output["task"] = tasks[0]
|
||||
# # return output
|
|
@ -15,6 +15,7 @@ import re
|
|||
from openprompt.prompts import ManualTemplate, ManualVerbalizer
|
||||
from openprompt.plms.utils import TokenizerWrapper
|
||||
from openprompt.data_utils import InputExample
|
||||
from openprompt.prompts import GenerationVerbalizer
|
||||
import itertools
|
||||
|
||||
|
||||
|
@ -28,184 +29,170 @@ from typing import List, Dict
|
|||
from collections import defaultdict
|
||||
from openprompt.utils import round_list
|
||||
import warnings
|
||||
class MLMTokenizerWrapper:
|
||||
def __init__(self, max_seq_length, tokenizer, truncate_method):
|
||||
self.max_seq_length=max_seq_length
|
||||
self.tokenizer=tokenizer
|
||||
self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
|
||||
# from IPython import embed; embed(header="Truega")
|
||||
self.truncate_method=truncate_method
|
||||
self.total_passed_sentences = 0
|
||||
self.num_truncated_sentences = 0
|
||||
if truncate_method=='tail':
|
||||
self.truncate_fct = self.truncate_from_tail
|
||||
elif truncate_method=='head':
|
||||
self.truncate_fct = self.truncate_from_head
|
||||
elif truncate_method == 'balanced':
|
||||
self.truncate_fct = self.balanced_truncate
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# class MLMTokenizerWrapper:
|
||||
# def __init__(self, max_seq_length, tokenizer, truncate_method, mask_token_func=lambda i: "<mask>"):
|
||||
# self.max_seq_length=max_seq_length
|
||||
# self.tokenizer=tokenizer
|
||||
# self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
|
||||
# # from IPython import embed; embed(header="Truega")
|
||||
# self.truncate_method=truncate_method
|
||||
# self.total_passed_sentences = 0
|
||||
# self.num_truncated_sentences = 0
|
||||
# self.mask_token_func = mask_token_func
|
||||
|
||||
# if truncate_method=='tail':
|
||||
# self.truncate_fct = self.truncate_from_tail
|
||||
# elif truncate_method=='head':
|
||||
# self.truncate_fct = self.truncate_from_head
|
||||
# elif truncate_method == 'balanced':
|
||||
# self.truncate_fct = self.balanced_truncate
|
||||
# else:
|
||||
# raise NotImplementedError
|
||||
|
||||
|
||||
# def merge_wrapped_example(self, wrapped_example,):
|
||||
# ''' # TODO doens't consider the situation that input has two parts
|
||||
# '''
|
||||
|
||||
# wrapped_example
|
||||
|
||||
# # for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
|
||||
# # the input. Or in generation tasks, we need to generate a piece of target_text.
|
||||
# # In these case, it tokenized to the encoded_tgt_text for furture use.
|
||||
|
||||
|
||||
|
||||
# encoder_inputs = defaultdict(list)
|
||||
# # from IPython import embed; embed(header="Line 67")
|
||||
|
||||
# mask_count = 0
|
||||
# for piece in wrapped_example:
|
||||
# if piece['text'] == "<mask>":
|
||||
# encode_text = self.tokenizer.encode(self.mask_token_func(mask_count), add_special_tokens=False, return_special_tokens_mask=True )
|
||||
# mask_count += 1
|
||||
# else:
|
||||
# encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
|
||||
# encoder_inputs['input_ids'].append(encode_text)
|
||||
# encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
|
||||
|
||||
|
||||
def merge_wrapped_example(self, wrapped_example, ):
|
||||
''' # TODO doens't consider the situation that input has two parts
|
||||
'''
|
||||
# encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
|
||||
# encoder_inputs.pop("shortenable_ids")
|
||||
# encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
||||
# decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
|
||||
|
||||
wrapped_example, others = wrapped_example
|
||||
# return decoded_inputs
|
||||
|
||||
# for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
|
||||
# the input. Or in generation tasks, we need to generate a piece of target_text.
|
||||
# In these case, it tokenized to the encoded_tgt_text for furture use.
|
||||
|
||||
# @staticmethod
|
||||
# def balanced_truncate(input_dict: Dict,
|
||||
# num_tokens_to_truncate: int=0) -> Dict:
|
||||
# '''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
|
||||
# '''
|
||||
# shortenable_lens = [len(parts) if parts[0]==1 else 0
|
||||
# for parts in input_dict['shortenable_ids']]
|
||||
# total_shortenable_len = sum(shortenable_lens)
|
||||
# num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
|
||||
# for part_len in shortenable_lens]
|
||||
# round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
|
||||
|
||||
# truncated_example = defaultdict(list)
|
||||
# for key in input_dict:
|
||||
# parts = input_dict[key]
|
||||
# for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
|
||||
# truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
|
||||
# return truncated_example
|
||||
|
||||
# @staticmethod
|
||||
# def truncate_from_tail(input_dict: Dict,
|
||||
# num_tokens_to_truncate: int=0) -> Dict:
|
||||
# r"""truncate the inputs from the rear
|
||||
# """
|
||||
# truncated_example = defaultdict(list)
|
||||
# shortenable_ids = input_dict['shortenable_ids']
|
||||
|
||||
# for key in input_dict:
|
||||
# parts = input_dict[key]
|
||||
# to_trunc = num_tokens_to_truncate
|
||||
# for i, part in enumerate(parts[::-1]):
|
||||
# if len(part) == 0: # to prevent some part are empty after tokenization
|
||||
# continue
|
||||
# if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
|
||||
# continue
|
||||
# parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||
# to_trunc -= len(part)
|
||||
# if to_trunc <= 0:
|
||||
# break
|
||||
# truncated_example[key] = parts
|
||||
# return truncated_example
|
||||
|
||||
# @staticmethod
|
||||
# def truncate_from_head(input_dict: Dict,
|
||||
# num_tokens_to_truncate: int=0) -> Dict:
|
||||
# r"""truncate the inputs from the head
|
||||
# """
|
||||
# truncated_example = defaultdict(list)
|
||||
# shortenable_ids = input_dict['shortenable_ids']
|
||||
# for key in input_dict:
|
||||
# parts = input_dict[key]
|
||||
# to_trunc = num_tokens_to_truncate
|
||||
# for i, part in enumerate(parts):
|
||||
# if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
|
||||
# continue
|
||||
# parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||
# to_trunc -= len(part)
|
||||
# if to_trunc <= 0:
|
||||
# break
|
||||
# truncated_example[key] = parts
|
||||
# return truncated_example
|
||||
|
||||
# @staticmethod
|
||||
# def concate_parts(input_dict: Dict) -> Dict:
|
||||
# for key in input_dict:
|
||||
# input_dict[key] = list(itertools.chain(*input_dict[key]))
|
||||
# return input_dict
|
||||
|
||||
|
||||
# def truncate(self, encoder_inputs):
|
||||
# total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
|
||||
# num_specials = self.num_special_tokens_to_add
|
||||
# # print("num_specials", num_specials)
|
||||
# num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
|
||||
# self.total_passed_sentences+=1
|
||||
# if num_tokens_to_truncate>0:
|
||||
# self.num_truncated_sentences += 1
|
||||
# if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
|
||||
# raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
|
||||
# encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
|
||||
# num_tokens_to_truncate=num_tokens_to_truncate)
|
||||
# return encoder_inputs
|
||||
|
||||
# def tokenizer_preprocessor(self, example):
|
||||
# # source, target = example
|
||||
# # from IPython import embed; embed(header="Trehre2")
|
||||
# label = example['label']
|
||||
# guid = example['idx']
|
||||
# meta = dict(example)
|
||||
# meta.pop("label")
|
||||
# meta.pop("idx")
|
||||
|
||||
|
||||
|
||||
encoder_inputs = defaultdict(list)
|
||||
for piece in wrapped_example:
|
||||
encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
|
||||
encoder_inputs['input_ids'].append(encode_text)
|
||||
encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
|
||||
# # from IPython import embed; embed(header="Trehre2")
|
||||
|
||||
# e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
|
||||
|
||||
encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
|
||||
encoder_inputs.pop("shortenable_ids")
|
||||
encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
||||
decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
|
||||
|
||||
# again_encode = self.tokenizer.encode(decoded_inputs, add_special_tokens=False, return_special_tokens_mask=True)
|
||||
# if len(again_encode)> self.max_seq_length - 2:
|
||||
# print("length exceed!")
|
||||
# print(wrapped_example)
|
||||
# print(encoder_inputs['input_ids'])
|
||||
# print(again_encode)
|
||||
# print(decoded_inputs)
|
||||
# exit()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# delete shortenable ids
|
||||
|
||||
# encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
|
||||
# encoder_inputs = self.add_special_tokens(encoder_inputs=encoder_inputs)
|
||||
# # create special input ids
|
||||
# encoder_inputs['attention_mask'] = [1] *len(encoder_inputs['input_ids'])
|
||||
# # padding
|
||||
# encoder_inputs = self.padding(input_dict=encoder_inputs, max_len=self.max_seq_length, pad_id_for_inputs=self.tokenizer.pad_token_id)
|
||||
|
||||
return decoded_inputs
|
||||
|
||||
@staticmethod
|
||||
def balanced_truncate(input_dict: Dict,
|
||||
num_tokens_to_truncate: int=0) -> Dict:
|
||||
'''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
|
||||
'''
|
||||
shortenable_lens = [len(parts) if parts[0]==1 else 0
|
||||
for parts in input_dict['shortenable_ids']]
|
||||
total_shortenable_len = sum(shortenable_lens)
|
||||
num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
|
||||
for part_len in shortenable_lens]
|
||||
round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
|
||||
|
||||
truncated_example = defaultdict(list)
|
||||
for key in input_dict:
|
||||
parts = input_dict[key]
|
||||
for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
|
||||
truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
|
||||
return truncated_example
|
||||
|
||||
@staticmethod
|
||||
def truncate_from_tail(input_dict: Dict,
|
||||
num_tokens_to_truncate: int=0) -> Dict:
|
||||
r"""truncate the inputs from the rear
|
||||
"""
|
||||
truncated_example = defaultdict(list)
|
||||
shortenable_ids = input_dict['shortenable_ids']
|
||||
|
||||
for key in input_dict:
|
||||
parts = input_dict[key]
|
||||
to_trunc = num_tokens_to_truncate
|
||||
for i, part in enumerate(parts[::-1]):
|
||||
if len(part) == 0: # to prevent some part are empty after tokenization
|
||||
continue
|
||||
if shortenable_ids[-1-i][0]==0: # ==0 means the part is not shortenable
|
||||
continue
|
||||
parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||
to_trunc -= len(part)
|
||||
if to_trunc <= 0:
|
||||
break
|
||||
truncated_example[key] = parts
|
||||
return truncated_example
|
||||
|
||||
@staticmethod
|
||||
def truncate_from_head(input_dict: Dict,
|
||||
num_tokens_to_truncate: int=0) -> Dict:
|
||||
r"""truncate the inputs from the head
|
||||
"""
|
||||
truncated_example = defaultdict(list)
|
||||
shortenable_ids = input_dict['shortenable_ids']
|
||||
for key in input_dict:
|
||||
parts = input_dict[key]
|
||||
to_trunc = num_tokens_to_truncate
|
||||
for i, part in enumerate(parts):
|
||||
if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
|
||||
continue
|
||||
parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
|
||||
to_trunc -= len(part)
|
||||
if to_trunc <= 0:
|
||||
break
|
||||
truncated_example[key] = parts
|
||||
return truncated_example
|
||||
|
||||
@staticmethod
|
||||
def concate_parts(input_dict: Dict) -> Dict:
|
||||
for key in input_dict:
|
||||
input_dict[key] = list(itertools.chain(*input_dict[key]))
|
||||
return input_dict
|
||||
|
||||
# @staticmethod
|
||||
# def padding(input_dict: Dict,
|
||||
# max_len: int, pad_id_for_inputs: int=0, pad_id_for_others: int=0) -> None:
|
||||
# for key, value in input_dict.items():
|
||||
# if (len(input_dict[key]) > max_len):
|
||||
# raise ValueError(f'''
|
||||
# Truncated seq length of '{key}' still greater than max length '{max_len}.'
|
||||
# One possible reason is that no enough shortenable parts in template. Try add {{"shortenable": "True"}} property.
|
||||
# ''')
|
||||
# if 'input' in key:
|
||||
# input_dict[key].extend([pad_id_for_inputs]*(max_len-len(value)))
|
||||
# else:
|
||||
# input_dict[key].extend([pad_id_for_others]*(max_len-len(value)))
|
||||
# return input_dict
|
||||
|
||||
|
||||
# def add_special_tokens(self, encoder_inputs):
|
||||
# # add special tokens
|
||||
# for key in encoder_inputs:
|
||||
# if key == "input_ids":
|
||||
# with warnings.catch_warnings():
|
||||
# warnings.simplefilter("ignore")
|
||||
# encoder_inputs[key] = self.tokenizer.build_inputs_with_special_tokens(
|
||||
# encoder_inputs[key])
|
||||
# return encoder_inputs
|
||||
|
||||
def truncate(self, encoder_inputs):
|
||||
total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
|
||||
num_specials = self.num_special_tokens_to_add
|
||||
# print("num_specials", num_specials)
|
||||
num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
|
||||
self.total_passed_sentences+=1
|
||||
if num_tokens_to_truncate>0:
|
||||
self.num_truncated_sentences += 1
|
||||
if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
|
||||
raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
|
||||
encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
|
||||
num_tokens_to_truncate=num_tokens_to_truncate)
|
||||
return encoder_inputs
|
||||
# if self.predict_with_generate:
|
||||
# e = self.verbalizer.wrap_one_example(e)
|
||||
# example_wrapped = self.template.wrap_one_example(e)
|
||||
# encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(example_wrapped)
|
||||
# print(encoded_sentence)
|
||||
# if self.predict_with_generate:
|
||||
# # return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
|
||||
# return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
|
||||
# else:
|
||||
# return {"source": encoded_sentence, "label": label, 'target': e.target_text, 'extra_fields':{'dataset_name':self.name}}
|
||||
|
||||
|
||||
|
||||
|
@ -234,46 +221,23 @@ class AbstractTask(abc.ABC):
|
|||
"superglue-boolq", "qqp", "qnli", "superglue-record", "sst2"]
|
||||
large_data_without_all_splits = [] #["qqp", "qnli", "superglue-record", "sst2"]
|
||||
|
||||
def __init__(self, config, data_args, tokenizer, predict_with_generate, seed=42, default_max_length=1):
|
||||
def __init__(self, config, data_args, seed=42, default_max_length=1):
|
||||
self.config = config
|
||||
self.seed = seed
|
||||
self.data_args = data_args
|
||||
self.tokenizer = tokenizer
|
||||
self.predict_with_generate = predict_with_generate
|
||||
# self.tokenizer = tokenizer
|
||||
# self.predict_with_generate = predict_with_generate
|
||||
self.default_max_length = default_max_length
|
||||
self.truncate_method = getattr(data_args, "truncate_method", "balanced")
|
||||
|
||||
tid = getattr(config, "template_id", 0)
|
||||
vid = getattr(config, "verbalizer_id", 0)
|
||||
|
||||
self.template = ManualTemplate(tokenizer=self.tokenizer, text = self.templates_text[tid])
|
||||
self.verbalizer = ManualVerbalizer(tokenizer=self.tokenizer, classes = self.labels_list, label_words=self.verbalizers[vid])
|
||||
|
||||
# if self.predict_with_generate:
|
||||
# self.reverse_verbalizer = {(int(x) for x in self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))): label for label in self.labels_list}
|
||||
# else:
|
||||
# self.reverse_verbalizer = {int(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(self.verbalizer[label]))[0]): label for label in self.labels_list}
|
||||
|
||||
self.tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=self.data_args.max_source_length, tokenizer=self.tokenizer, truncate_method=self.truncate_method)
|
||||
|
||||
generation_paradigm = getattr(config, "generation_paradigm", True)
|
||||
# generation_paradigm = getattr(config, "generation_paradigm", True)
|
||||
# self.prompt = PromptCollections[self.name](tid, vid, generation_paradigm)
|
||||
self.max_target_length = self.get_max_target_length(self.default_max_length)
|
||||
|
||||
def get_max_target_length(self, default_max_length):
|
||||
if self.predict_with_generate:
|
||||
return max([len(label) for key, label in self.verbalizer.label_words_ids.items()])
|
||||
else:
|
||||
return default_max_length
|
||||
|
||||
def seq2seq_format(self, source, target, extra_fields={}
|
||||
):
|
||||
|
||||
return {'source': ' '.join(source),
|
||||
'target': ' '.join(target),
|
||||
'task': self.name,
|
||||
'extra_fields': extra_fields
|
||||
}
|
||||
# def get_max_target_length(self, default_max_length):
|
||||
# if self.predict_with_generate:
|
||||
# return -1
|
||||
# else:
|
||||
# return default_max_length
|
||||
|
||||
def check_n_obs(self, n_obs, total_size):
|
||||
if n_obs is not None and n_obs > total_size:
|
||||
|
@ -312,37 +276,9 @@ class AbstractTask(abc.ABC):
|
|||
else:
|
||||
return indices[validation_size:]
|
||||
|
||||
|
||||
def map_dataset(self, dataset):
|
||||
# from IPython import embed; embed(header="in get target length")
|
||||
return dataset.map(self.preprocessor).map(self.tokenizer_preprocessor)
|
||||
|
||||
def preprocessor(self, example):
|
||||
return example
|
||||
|
||||
def tokenizer_preprocessor(self, example):
|
||||
# source, target = example
|
||||
# from IPython import embed; embed(header="Trehre2")
|
||||
label = example['label']
|
||||
guid = example['idx']
|
||||
meta = dict(example)
|
||||
meta.pop("label")
|
||||
meta.pop("idx")
|
||||
|
||||
|
||||
|
||||
# from IPython import embed; embed(header="Trehre2")
|
||||
|
||||
e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
|
||||
template_e = self.template.wrap_one_example(e)
|
||||
encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(template_e)
|
||||
if self.predict_with_generate:
|
||||
# return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
|
||||
|
||||
|
||||
def get(self, split, n_obs=None, split_validation_test=False):
|
||||
# For small datasets (n_samples < 10K) without test set, we divide validation set to
|
||||
# half, use one half as test set and one half as validation set.
|
||||
|
@ -368,7 +304,7 @@ class AbstractTask(abc.ABC):
|
|||
# shuffles the data and samples it.
|
||||
if n_obs is not None:
|
||||
dataset = self.subsample(dataset, n_obs)
|
||||
return self.map_dataset(dataset)
|
||||
return dataset.map(self.preprocessor)
|
||||
|
||||
class Squad(AbstractTask):
|
||||
name = "squad"
|
||||
|
@ -387,25 +323,7 @@ class Squad(AbstractTask):
|
|||
return self.seq2seq_format(source, target, add_prefix)
|
||||
|
||||
|
||||
class MRPC(AbstractTask):
|
||||
name = "mrpc"
|
||||
labels_list = ["0", "1"]
|
||||
metric = [metrics.f1_score, metrics.accuracy]
|
||||
metric_names = ["f1", "accuracy"]
|
||||
split_to_data_split = {"train": "train",
|
||||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
def load_dataset(self, split):
|
||||
return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")
|
||||
|
||||
# def preprocessor(self, example, add_prefix=True):
|
||||
# src_texts = ["sentence1:", example['sentence1'],
|
||||
# "sentence2:", example["sentence2"]]
|
||||
# tgt_texts = [str(example['label'])]
|
||||
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
##GLUE
|
||||
class COLA(AbstractTask):
|
||||
name = "cola"
|
||||
labels_list = ["0", "1"]
|
||||
|
@ -415,15 +333,20 @@ class COLA(AbstractTask):
|
|||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
templates_text = {"0": """sentence: {"meta": 'sentence', "shortenable":True} Are there any error in the sentence? {"mask"}""",
|
||||
}
|
||||
|
||||
verbalizers = {
|
||||
"0":{ "0": "yes", "1": "no"}
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.cola")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'cola',
|
||||
split=split, script_version="master")
|
||||
|
||||
# def preprocessor(self, example, add_prefix=True):
|
||||
# src_texts = ["sentence:", example['sentence']]
|
||||
# tgt_texts = [str(example['label'])]
|
||||
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
class SST2(AbstractTask):
|
||||
name = "sst2"
|
||||
|
@ -434,38 +357,50 @@ class SST2(AbstractTask):
|
|||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
verbalizers = [
|
||||
verbalizers = {
|
||||
"0":{"0":"negative","1":"positive"}
|
||||
}
|
||||
|
||||
]
|
||||
templates_text = {
|
||||
"0":"""The sentiment of sentence: "{"meta":"sentence", "shortenable":True} is {"mask"}."""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.sst2")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'sst2',
|
||||
split=split, script_version="master")
|
||||
|
||||
def preprocessor(self, example, add_prefix=True):
|
||||
src_texts = ["sentence:", example['sentence']]
|
||||
tgt_texts = [str(example['label'])]
|
||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
class STSB(AbstractTask):
|
||||
name = "stsb"
|
||||
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)]
|
||||
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
|
||||
metric_names = ["pearson", "spearmanr"]
|
||||
class MRPC(AbstractTask):
|
||||
name = "mrpc"
|
||||
labels_list = ["0", "1"]
|
||||
metric = [metrics.f1_score, metrics.accuracy]
|
||||
metric_names = ["f1", "accuracy"]
|
||||
split_to_data_split = {"train": "train",
|
||||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
def load_dataset(self, split):
|
||||
return datasets.load_dataset('glue', 'stsb',
|
||||
split=split, script_version="master")
|
||||
|
||||
def preprocessor(self, example, add_prefix=True):
|
||||
src_texts = ["sentence1:", example['sentence1'],
|
||||
"sentence2:", example["sentence2"]]
|
||||
tgt_texts = [str(round_stsb_target(example['label']))]
|
||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
templates_text = {
|
||||
"0": """sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2: {"meta":"sentence2", "shortenable":True}. Are sentence1 and sentence2 equivalent? {"mask"}.""",
|
||||
}
|
||||
|
||||
verbalizers = {
|
||||
"0":{"0": "no","1": "yes"}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mrpc")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'mrpc', split=split, script_version="master")
|
||||
|
||||
|
||||
|
||||
class QQP(AbstractTask):
|
||||
|
@ -477,14 +412,46 @@ class QQP(AbstractTask):
|
|||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
templates_text = {"0":
|
||||
"""question1: {"meta": 'question1', "shortenable":True}. question2: {"meta": 'question2', "shortenable":True} Are question1 and question2 equivalent? {"mask"}."""
|
||||
}
|
||||
|
||||
verbalizers = {
|
||||
"0":{"0": "no","1": "yes"}
|
||||
}
|
||||
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qqp")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'qqp',
|
||||
split=split, script_version="master")
|
||||
|
||||
|
||||
|
||||
class STSB(AbstractTask):
|
||||
name = "stsb"
|
||||
labels_list = [str(np.round(label, decimals=1)) for label in np.arange(0, 5.2, 0.2)]
|
||||
metric = [metrics.pearson_corrcoef, metrics.spearman_corrcoef]
|
||||
metric_names = ["pearson", "spearmanr"]
|
||||
split_to_data_split = {"train": "train",
|
||||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
|
||||
verbalizers = {
|
||||
""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
return datasets.load_dataset('glue', 'stsb',
|
||||
split=split, script_version="master")
|
||||
|
||||
def preprocessor(self, example, add_prefix=True):
|
||||
src_texts = ["question1:", example['question1'],
|
||||
"question2:", example["question2"]]
|
||||
tgt_texts = [str(example['label'])]
|
||||
src_texts = ["sentence1:", example['sentence1'],
|
||||
"sentence2:", example["sentence2"]]
|
||||
tgt_texts = [str(round_stsb_target(example['label']))]
|
||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
|
@ -498,14 +465,29 @@ class MNLI(AbstractTask):
|
|||
metric_names = ["accuracy"]
|
||||
|
||||
|
||||
templates_text = {
|
||||
"0":"""premise: {"meta": 'premise', "shortenable":True}. hypothesis: {"meta": 'hypothesis', "shortenable":True} Does the premise entails the hypothesis? {"mask"}.""",
|
||||
}
|
||||
|
||||
verbalizers = {
|
||||
"0":{
|
||||
"0": "yes",
|
||||
"1": "neutral",
|
||||
"2": "no",
|
||||
}
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.mnli")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'mnli', split=split, script_version="master")
|
||||
|
||||
def preprocessor(self, example, add_prefix=True):
|
||||
src_texts = ["premise:", example['premise'],
|
||||
"hypothesis", example["hypothesis"]]
|
||||
tgt_texts = [str(example['label'])]
|
||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
# def preprocessor(self, example, add_prefix=True):
|
||||
# src_texts = ["premise:", example['premise'],
|
||||
# "hypothesis", example["hypothesis"]]
|
||||
# tgt_texts = [str(example['label'])]
|
||||
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
class QNLI(AbstractTask):
|
||||
|
@ -517,14 +499,33 @@ class QNLI(AbstractTask):
|
|||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
templates_text = {
|
||||
"0": """premise: {"meta": 'sentence', "shortenable":True}. hypothesis: {"meta": 'question', "shortenable":True}"""+
|
||||
"""Does the premise entails the hypothesis? {"mask"}.""",
|
||||
}
|
||||
|
||||
verbalizers = {
|
||||
"0":{
|
||||
"0": "yes",
|
||||
"1": "no",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/glue.qnli")[split]
|
||||
else:
|
||||
return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
|
||||
|
||||
def preprocessor(self, example, add_prefix=True):
|
||||
src_texts = ["question:", example['question'],
|
||||
"sentence:", example["sentence"]]
|
||||
tgt_texts = [str(example['label'])]
|
||||
return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
# def load_dataset(self, split):
|
||||
# return datasets.load_dataset('glue', 'qnli', split=split, script_version="master")
|
||||
|
||||
# def preprocessor(self, example, add_prefix=True):
|
||||
# src_texts = ["question:", example['question'],
|
||||
# "sentence:", example["sentence"]]
|
||||
# tgt_texts = [str(example['label'])]
|
||||
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
#Tested
|
||||
class RTE(AbstractTask):
|
||||
|
@ -537,15 +538,15 @@ class RTE(AbstractTask):
|
|||
"test": "validation"}
|
||||
|
||||
|
||||
templates_text = [
|
||||
"""sentence1: {"meta": 'sentence1', "shortenable":True}. sentence2:,"""+
|
||||
"""{"meta":"sentence2", "shortenable":True}. The answer was {"mask"}.""",
|
||||
]
|
||||
templates_text = {
|
||||
"0": """sentence1: {"meta": 'sentence1', "shortenable":True} sentence2: {"meta":"sentence2", "shortenable":True} The answer was {"mask"}.""",
|
||||
}
|
||||
|
||||
verbalizers = [{
|
||||
"0": "yes",
|
||||
verbalizers = {
|
||||
"0":{"0": "yes",
|
||||
"1": "no"
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
|
@ -555,38 +556,6 @@ class RTE(AbstractTask):
|
|||
split=split, script_version="master")
|
||||
|
||||
|
||||
#Tested
|
||||
class SuperGLUEBoolQ(AbstractTask):
|
||||
name="superglue-boolq"
|
||||
labels_list = ['0', '1']
|
||||
metric = [metrics.accuracy]
|
||||
metric_names = ["accuracy"]
|
||||
split_to_data_split = {"train": "train",
|
||||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
verbalizers = [
|
||||
{
|
||||
"0": "no",
|
||||
"1": "yes"
|
||||
},
|
||||
]
|
||||
mlmhead_verbalizers = {
|
||||
"0": "no",
|
||||
"1": "yes"
|
||||
}
|
||||
templates_text = [
|
||||
"""hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
|
||||
]
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
|
||||
else:
|
||||
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
|
||||
|
||||
|
||||
|
||||
|
||||
class WNLI(AbstractTask):
|
||||
name = "wnli"
|
||||
|
@ -597,13 +566,13 @@ class WNLI(AbstractTask):
|
|||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
verbalizers = [{
|
||||
"0": "True",
|
||||
verbalizers = {
|
||||
"0":{"0": "True",
|
||||
"1": "False",
|
||||
}]
|
||||
templates_text = [
|
||||
"""{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}."""
|
||||
]
|
||||
}
|
||||
}
|
||||
templates_text = {"0": """{"meta": 'sentence1',"shortenable":True} Does it mean the following: "{"meta":'sentence2'}"? {"mask"}."""
|
||||
}
|
||||
|
||||
|
||||
def load_dataset(self, split):
|
||||
|
@ -613,6 +582,34 @@ class WNLI(AbstractTask):
|
|||
return datasets.load_dataset('glue', 'wnli', split=split, script_version="master")
|
||||
|
||||
|
||||
#SuperGLUE
|
||||
class SuperGLUEBoolQ(AbstractTask):
|
||||
name="superglue-boolq"
|
||||
labels_list = ['0', '1']
|
||||
metric = [metrics.accuracy]
|
||||
metric_names = ["accuracy"]
|
||||
split_to_data_split = {"train": "train",
|
||||
"validation": "validation",
|
||||
"test": "validation"}
|
||||
|
||||
verbalizers = {
|
||||
"0": {
|
||||
"0": "no",
|
||||
"1": "yes"
|
||||
},
|
||||
}
|
||||
|
||||
templates_text = {
|
||||
"0": """hypothesis: {"meta": "question", "shortenable":True} premise: {"meta":"passage", "shortenable":True} The answer was {"mask"}."""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.boolq")[split]
|
||||
else:
|
||||
return datasets.load_dataset('super_glue', 'boolq', split=split, script_version="master")
|
||||
|
||||
|
||||
#
|
||||
class SuperGLUECB(AbstractTask):
|
||||
name = "superglue-cb"
|
||||
|
@ -623,14 +620,15 @@ class SuperGLUECB(AbstractTask):
|
|||
metric = [metrics.mean_multiclass_f1(num_classes=3), metrics.accuracy]
|
||||
metric_names = ["f1_multiclass", "accuracy"]
|
||||
|
||||
verbalizers = [{
|
||||
"0": "yes",
|
||||
verbalizers = {
|
||||
"0":{"0": "yes",
|
||||
"1": "no",
|
||||
"2": "maybe"
|
||||
}]
|
||||
templates_text = [
|
||||
"""hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}."""
|
||||
]
|
||||
}
|
||||
}
|
||||
templates_text = {
|
||||
"0": """hypothesis: {"meta": 'hypothesis',"shortenable":True} premise: {"meta":'premise', "shortenable":True} The answer was {"mask"}."""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
|
@ -648,13 +646,15 @@ class SuperGLUECOPA(AbstractTask):
|
|||
metric = [metrics.accuracy]
|
||||
metric_names = ["accuracy"]
|
||||
|
||||
verbalizers = [{
|
||||
verbalizers = {
|
||||
"0":{
|
||||
"0": "1",
|
||||
"1": "2",
|
||||
}]
|
||||
templates_text = [
|
||||
"""choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}."""
|
||||
]
|
||||
}
|
||||
}
|
||||
templates_text = {
|
||||
"0": """choice1: {"meta":"choice1"} choice2: {"meta":"choice2"} premise: {"meta":"premise", "shortenable":True} The {"meta":"question"} answer was choice{"mask"}."""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
|
@ -673,19 +673,16 @@ class SuperGLUEMultiRC(AbstractTask):
|
|||
metrics.accuracy]
|
||||
metric_names = ["f1", "em"]
|
||||
|
||||
# generation_verbalizers = [{
|
||||
# "0": "no",
|
||||
# "1": "yes",
|
||||
# },
|
||||
# ]
|
||||
|
||||
verbalizers = [{
|
||||
verbalizers = {
|
||||
"0": {
|
||||
"0": "no",
|
||||
"1": "yes",
|
||||
}]
|
||||
templates_text = [
|
||||
"""question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}."""
|
||||
]
|
||||
}
|
||||
}
|
||||
templates_text = {
|
||||
"0": """question: {"meta":"question", "shortenable":False} answer: {"meta":"answer", "shortenable":False, "post_processing": lambda x:x+"."} paragraph: {"meta":"paragraph", "shortenable":True} The answer was {"mask"}."""
|
||||
}
|
||||
|
||||
|
||||
def load_dataset(self, split):
|
||||
|
@ -720,15 +717,16 @@ class SuperGLUEWIC(AbstractTask):
|
|||
metric = [metrics.accuracy]
|
||||
metric_names = ["accuracy"]
|
||||
|
||||
verbalizers = [{
|
||||
verbalizers = {
|
||||
"0": {
|
||||
"0": "No",
|
||||
"1": "Yes",
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
templates_text = [
|
||||
"""sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}.
|
||||
"""
|
||||
]
|
||||
templates_text = {
|
||||
"0": """sentence1: {"meta":"sentence1"} sentence2: {"meta":"sentence2", "shortenable": True} word: {"meta":"word"} {"mask"}."""
|
||||
}
|
||||
|
||||
def load_dataset(self, split):
|
||||
if self.data_args.datasets_load_from_disk:
|
||||
|
@ -786,7 +784,7 @@ class SuperGLUEWIC(AbstractTask):
|
|||
# text = self._mark_span(text, example['span2_text'], span2_index, '#')
|
||||
# src_texts = ["text:", text]
|
||||
# tgt_texts = [str(example["label"])]
|
||||
# return self.seq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
# return self.fseq2seq_format(src_texts, tgt_texts, add_prefix)
|
||||
|
||||
|
||||
class SuperGLUERecord(AbstractTask):
|
||||
|
@ -875,9 +873,9 @@ TASK_MAPPING = OrderedDict(
|
|||
|
||||
class AutoTask:
|
||||
@classmethod
|
||||
def get(self, task, config, data_args, tokenizer,predict_with_generate, seed=42):
|
||||
def get(self, task, config, data_args, seed=42):
|
||||
if task in TASK_MAPPING:
|
||||
return TASK_MAPPING[task](config, data_args, tokenizer,predict_with_generate, seed)
|
||||
return TASK_MAPPING[task](config, data_args, seed)
|
||||
raise ValueError(
|
||||
"Unrecognized task {} for AutoTask Model: {}.\n"
|
||||
"Task name should be one of {}.".format(
|
||||
|
|
|
@ -34,6 +34,7 @@ from transformers import (
|
|||
AutoModelForMaskedLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
HfArgumentParser,
|
||||
MBartTokenizer,
|
||||
default_data_collator,
|
||||
|
@ -41,8 +42,8 @@ from transformers import (
|
|||
)
|
||||
from transformers.trainer_utils import is_main_process, get_last_checkpoint
|
||||
# from ..seq2seq.utils import get_adapter_config
|
||||
from examples_prompt.data_processors import AutoTask, TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
|
||||
from examples_prompt.seq2seq_trainer import Seq2SeqTrainer
|
||||
from examples_prompt.data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
|
||||
from transformers import Seq2SeqTrainer
|
||||
# from training_args import AdapterTrainingArguments
|
||||
from examples_prompt.trainers.trainer_utils import save_training_config
|
||||
from dataclasses import dataclass, field
|
||||
|
@ -56,7 +57,8 @@ import json
|
|||
import numpy as np
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
import os
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
TASK_TO_METRICS = {"mrpc": ["accuracy", "f1"],
|
||||
"cola": ['matthews_correlation'],
|
||||
|
@ -109,7 +111,6 @@ class RemainArgHfArgumentParser(HfArgumentParser):
|
|||
|
||||
|
||||
def main():
|
||||
|
||||
# See all possible arguments in src/transformers/training_args.py
|
||||
# or by passing the --help flag to this script.
|
||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||
|
@ -202,6 +203,16 @@ def main():
|
|||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
|
@ -212,11 +223,6 @@ def main():
|
|||
)
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
from openprompt.prompts import ManualTemplate
|
||||
from openprompt.plms import MLMTokenizerWrapper
|
||||
template = ManualTemplate(tokenizer, text="""sentence1: {"meta": 'premise'}, sentence2:,"""+
|
||||
"""{"meta":"hypothesis", "shortenable":True}, The answer was {"mask"} .""")
|
||||
tokenizer_wrapper = MLMTokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method='tail')
|
||||
|
||||
|
||||
|
||||
|
@ -248,49 +254,185 @@ def main():
|
|||
|
||||
# Temporarily set max_target_length for training.
|
||||
#max_target_length = data_args.max_target_length
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
|
||||
def preprocess_function(examples, **kwargs):
|
||||
# max_target_length += 1
|
||||
tokenizer = kwargs['tokenizer']
|
||||
data_args = kwargs['data_args']
|
||||
|
||||
|
||||
|
||||
|
||||
print("max_length", data_args.max_source_length)
|
||||
model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
|
||||
padding="max_length", truncation=True)
|
||||
|
||||
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
|
||||
# model_inputs["mask_position"] = mask_position
|
||||
model_inputs["extra_fields"] = examples['extra_fields']
|
||||
# from IPython import embed; embed(header="Therer")
|
||||
return model_inputs
|
||||
|
||||
|
||||
column_names = ['source', 'target', 'label', 'extra_fields']
|
||||
performance_metrics = {}
|
||||
|
||||
|
||||
|
||||
|
||||
def get_prompts(task, tokenizer, predict_with_generate, template_id="0", verbalizer_id="0"):
|
||||
# tid = getattr(config, "template_id", "0")
|
||||
# vid = getattr(config, "verbalizer_id", "0")
|
||||
from openpromptu.prompts import GenerationVerbalizer, ManualVerbalizer
|
||||
from openpromptu.prompts import ManualTemplate
|
||||
template = ManualTemplate(text = task.templates_text[template_id])
|
||||
if predict_with_generate:
|
||||
verbalizer = GenerationVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
||||
else:
|
||||
verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
||||
# max_target_length = self.get_max_target_length(self.default_max_length)
|
||||
|
||||
from openpromptu import TokenizerWrapper
|
||||
tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
|
||||
return template, verbalizer, tokenizer_wrapper
|
||||
|
||||
|
||||
from openpromptu.data_utils import InputExample
|
||||
|
||||
max_target_length = 32
|
||||
|
||||
if os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
||||
mask_token_func = lambda i: tokenizer.additional_special_tokens[i]
|
||||
def preprocess_function(raw_example, **kwargs):
|
||||
# max_target_length += 1
|
||||
tokenizer = kwargs['tokenizer']
|
||||
data_args = kwargs['data_args']
|
||||
template = kwargs['template']
|
||||
verbalizer = kwargs['verbalizer']
|
||||
tokenizer_wrapper = kwargs['tokenizer_wrapper']
|
||||
split = kwargs['split']
|
||||
# extra_fileds = example['extra_fields']
|
||||
|
||||
example = InputExample(**raw_example)
|
||||
|
||||
# from collections import namedtuple
|
||||
# example['tgt_text'] = ""
|
||||
# example = namedtuple("ObjectName", example.keys())(*example.values())
|
||||
try:
|
||||
example = verbalizer.wrap_one_example(example)
|
||||
example, other = template.wrap_one_example(example)
|
||||
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
|
||||
model_inputs = tokenizer(input_sentence, max_length=256,
|
||||
padding="max_length", truncation=True)
|
||||
except:
|
||||
from IPython import embed; embed(header="Therer")
|
||||
|
||||
|
||||
# if split == "train":
|
||||
with tokenizer.as_target_tokenizer():
|
||||
label = tokenizer(other['tgt_text']).input_ids
|
||||
# label = [l if l != tokenizer.pad_token_id else -100 for l in label]
|
||||
|
||||
# from IPython import embed; embed(header="Therer")
|
||||
model_inputs["labels"] = label
|
||||
# else:
|
||||
# # from IPython import embed; embed(header="Therer")
|
||||
# model_inputs["tgt_text"] = other['tgt_text']
|
||||
# model_inputs['labels'] = None # model_inputs["extra_fields"] = extra_fileds
|
||||
# from IPython import embed; embed(header="Therer2")
|
||||
return model_inputs
|
||||
|
||||
def compute_metrics(eval_preds, tokenizer, dataset_name, eval_metric):
|
||||
# from IPython import embed; embed(header="In compute metrics")
|
||||
preds, labels = eval_preds
|
||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
# post_processor = .get(data_args.dataset_name[0], tokenizer,
|
||||
# data_args.ignore_pad_token_for_loss)
|
||||
# decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
|
||||
result = {}
|
||||
for metric in eval_metric:
|
||||
result.update(metric(decoded_preds, decoded_labels))
|
||||
|
||||
average_metric = sum(result.values())/len(result)
|
||||
result.update({"average_metrics":average_metric})
|
||||
return result
|
||||
|
||||
|
||||
|
||||
elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
|
||||
or os.path.basename(model_args.model_name_or_path).startswith("bert"):
|
||||
mask_token_func = lambda i: tokenizer.mask_token
|
||||
def preprocess_function(raw_example, **kwargs):
|
||||
# max_target_length += 1
|
||||
|
||||
# from IPython import embed; embed(header="Therer")
|
||||
tokenizer = kwargs['tokenizer']
|
||||
|
||||
data_args = kwargs['data_args']
|
||||
template = kwargs['template']
|
||||
verbalizer = kwargs['verbalizer']
|
||||
tokenizer_wrapper = kwargs['tokenizer_wrapper']
|
||||
|
||||
example = InputExample(**raw_example)
|
||||
|
||||
# from collections import namedtuple
|
||||
# example['tgt_text'] = ""
|
||||
# example = namedtuple("ObjectName", example.keys())(*example.values())
|
||||
# try:
|
||||
# example = verbalizer.wrap_one_example(example)
|
||||
example, other = template.wrap_one_example(example)
|
||||
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
|
||||
model_inputs = tokenizer(input_sentence, max_length=256,
|
||||
padding="max_length", truncation=True)
|
||||
|
||||
|
||||
|
||||
# print("max_length", data_args.max_source_length)
|
||||
# model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
|
||||
# padding="max_length", truncation=True)
|
||||
|
||||
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
|
||||
# model_inputs["mask_position"] = mask_position
|
||||
# model_inputs["extra_fields"] = examples['extra_fields']
|
||||
# from IPython import embed; embed(header="Therer")
|
||||
return model_inputs
|
||||
|
||||
def compute_metrics(eval_preds, dataset_name):
|
||||
# from IPython import embed; embed(header="In compute metrics")
|
||||
|
||||
preds, labels = eval_preds.predictions, eval_preds.label_ids
|
||||
|
||||
preds = np.argmax(preds, axis=-1)
|
||||
|
||||
result = {}
|
||||
average_metrics = []
|
||||
for metric in eval_metric:
|
||||
metric_item = metric(preds, labels)
|
||||
metric_value = list(metric_item.values())
|
||||
result.update(metric_item)
|
||||
average_metrics.extend(metric_value)
|
||||
print("average:",average_metrics)
|
||||
average_metric = sum(average_metrics)/len(average_metrics)
|
||||
result.update({"average_metrics":average_metric})
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if training_args.do_train:
|
||||
|
||||
train_task = AutoTask.get(data_args.task_name,
|
||||
data_args.dataset_config_name,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer,
|
||||
predict_with_generate=training_args.predict_with_generate,
|
||||
# tokenizer=tokenizer,
|
||||
# predict_with_generate=training_args.predict_with_generate,
|
||||
seed=data_args.data_seed)
|
||||
|
||||
train_dataset = train_task.get(split='train',
|
||||
split_validation_test=training_args.split_validation_test,
|
||||
n_obs=data_args.max_train_samples)
|
||||
|
||||
template, verbalizer, tokenizer_wrapper = get_prompts(train_task, tokenizer, training_args.predict_with_generate)
|
||||
|
||||
|
||||
|
||||
train_dataset = train_dataset.map(
|
||||
functools.partial(preprocess_function,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer),
|
||||
batched=True,
|
||||
tokenizer=tokenizer,
|
||||
template=template,
|
||||
verbalizer=verbalizer,
|
||||
tokenizer_wrapper=tokenizer_wrapper,
|
||||
split="train"),
|
||||
batched=False,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=[x for x in train_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
|
@ -298,6 +440,8 @@ def main():
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
eval_splits_names = []
|
||||
|
||||
if training_args.do_eval:
|
||||
|
@ -306,87 +450,49 @@ def main():
|
|||
eval_splits_names.append("test")
|
||||
eval_splits = {}
|
||||
for split_name in eval_splits_names:
|
||||
_tasks = {dataset_name: AutoTask.get(dataset_name,
|
||||
dataset_config_name,
|
||||
eval_task = AutoTask.get(data_args.task_name,
|
||||
data_args.dataset_config_name,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer,
|
||||
predict_with_generate=training_args.predict_with_generate,
|
||||
# tokenizer=tokenizer,
|
||||
# predict_with_generate=training_args.predict_with_generate,
|
||||
seed=data_args.data_seed)
|
||||
for dataset_name, dataset_config_name\
|
||||
in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}
|
||||
# for dataset_name, dataset_config_name\
|
||||
# in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}
|
||||
|
||||
_datasets = {dataset_name: task.get(split=split_name,
|
||||
eval_dataset = eval_task.get(split=split_name,
|
||||
split_validation_test=training_args.split_validation_test,
|
||||
n_obs=data_args.max_train_samples)
|
||||
for dataset_name, task in _tasks.items()
|
||||
}
|
||||
|
||||
_datasets = {dataset_name: d.map(
|
||||
|
||||
|
||||
template, _verbalizer, tokenizer_wrapper = get_prompts(eval_task, tokenizer, training_args.predict_with_generate)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
functools.partial(preprocess_function,
|
||||
data_args=data_args,
|
||||
tokenizer=tokenizer),
|
||||
batched=True,
|
||||
tokenizer=tokenizer,
|
||||
template=template,
|
||||
verbalizer=_verbalizer,
|
||||
tokenizer_wrapper=tokenizer_wrapper,
|
||||
split=split_name),
|
||||
batched=False,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=[x for x in d.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
|
||||
remove_columns=[x for x in eval_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
for dataset_name, d in _datasets.items()
|
||||
}
|
||||
|
||||
eval_splits[split_name] = _datasets
|
||||
|
||||
eval_splits[split_name] = eval_dataset
|
||||
if split_name == "test":
|
||||
eval_metrics = {dataset_name:task.metric for dataset_name, task in _tasks.items()}
|
||||
verbalizers = {dataset_name:task.verbalizer for dataset_name, task in _tasks.items()}
|
||||
eval_metric = eval_task.metric
|
||||
verbalizer = _verbalizer
|
||||
|
||||
|
||||
# Metric, we assume we have only one training task.
|
||||
# eval_metrics = [task.metric for task in
|
||||
# for dataset_name, dataset_config_name in zip(data_args.dataset_name, data_args.dataset_config_name)][0]
|
||||
|
||||
# Extracts the extra information needed to evaluate on each dataset.
|
||||
# These information are only used in the compute_metrics.
|
||||
# We will assume that the test/eval dataloader does not change the order of
|
||||
# the data.
|
||||
# data_info = {"eval": eval_datasets[data_args.eval_dataset_name[0]]['extra_fields'],
|
||||
# "test": test_datasets[data_args.test_dataset_name[0]]['extra_fields'],
|
||||
# "train": train_dataset['extra_fields']}
|
||||
def compute_metrics(eval_preds, dataset_name):
|
||||
|
||||
preds, labels = eval_preds.predictions, eval_preds.label_ids
|
||||
|
||||
preds = np.argmax(preds, axis=-1)
|
||||
|
||||
result = {}
|
||||
average_metrics = []
|
||||
for metric in eval_metrics[dataset_name]:
|
||||
metric_item = metric(preds, labels)
|
||||
metric_value = list(metric_item.values())
|
||||
result.update(metric_item)
|
||||
average_metrics.extend(metric_value)
|
||||
print("average:",average_metrics)
|
||||
average_metric = sum(average_metrics)/len(average_metrics)
|
||||
result.update({"average_metrics":average_metric})
|
||||
return result
|
||||
|
||||
# from IPython import embed; embed(header="isseq2seq")
|
||||
# Initialize our Trainer
|
||||
# if training_args.is_seq2seq == True:
|
||||
# trainer = Seq2SeqTrainer(
|
||||
# model=model,
|
||||
# args=training_args,
|
||||
# delta_args=delta_args,
|
||||
# train_dataset=splits['train'] if training_args.do_train else None,
|
||||
# eval_dataset=list(splits['validation'].values())[0] if training_args.do_eval else None,
|
||||
# data_info = data_info,
|
||||
# tokenizer=tokenizer,
|
||||
# data_collator=data_collator,
|
||||
# compute_metrics=compute_metrics if training_args.predict_with_generate else None,
|
||||
# evaluation_metrics = TASK_TO_METRICS[data_args.dataset_name[0]],
|
||||
# )
|
||||
# else:
|
||||
|
||||
class MLMTrainer(Trainer):
|
||||
_verbalizers = verbalizers
|
||||
def __init__(self, verbalizer=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.verbalizer=verbalizer
|
||||
|
||||
# def training_step(self, model, inputs):
|
||||
# from IPython import embed; embed(header="in trainstep")
|
||||
|
@ -394,7 +500,7 @@ def main():
|
|||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
|
||||
labels = inputs.pop('labels')
|
||||
extra_fields = inputs.pop("extra_fields")
|
||||
# extra_fields = inputs.pop("extra_fields")
|
||||
outputs = model(**inputs)
|
||||
logits = outputs.get("logits")
|
||||
input_ids = inputs['input_ids']
|
||||
|
@ -402,20 +508,8 @@ def main():
|
|||
|
||||
|
||||
# from IPython import embed; embed(header="382")
|
||||
verbalizer = self._verbalizers[extra_fields[0]['dataset_name']].cuda()
|
||||
verbalizer = self.verbalizer.cuda()
|
||||
logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
|
||||
|
||||
# colidx = torch.where(input_ids == verbalizer.tokenizer.mask_token_id)[0].cpu()
|
||||
# print(colidx)
|
||||
# missing = set([i for i in range(input_ids.size(0))]) - set(colidx.numpy())
|
||||
# print(missing)
|
||||
# if len(missing) > 0:
|
||||
# print("missing")
|
||||
# missing = list(missing)[0]
|
||||
# input_ids_missing = input_ids[missing]
|
||||
# print(input_ids_missing)
|
||||
# missing_tokens = verbalizer.tokenizer.convert_ids_to_tokens(input_ids_missing)
|
||||
# print(missing_tokens)
|
||||
label_logits = verbalizer.process_logits(logits_at_mask)
|
||||
loss_fct = torch.nn.CrossEntropyLoss()
|
||||
# from IPython import embed; embed(header="In compute loss")
|
||||
|
@ -423,21 +517,136 @@ def main():
|
|||
outputs.logits = label_logits
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
class MySeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# from IPython import embed; embed(header="agag")
|
||||
|
||||
intlabel = inputs.pop('label')
|
||||
# extra_fields = inputs.pop("extra_fields")
|
||||
outputs = model(**inputs)
|
||||
# logits = outputs.get("logits")
|
||||
# input_ids = inputs['input_ids']
|
||||
|
||||
|
||||
|
||||
# # from IPython import embed; embed(header="382")
|
||||
# verbalizer = self._verbalizers.cuda()
|
||||
# logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
|
||||
# label_logits = verbalizer.process_logits(logits_at_mask)
|
||||
# loss_fct = torch.nn.CrossEntropyLoss()
|
||||
# # from IPython import embed; embed(header="In compute loss")
|
||||
# loss = loss_fct(label_logits, labels)
|
||||
# outputs.logits = label_logits
|
||||
if return_outputs:
|
||||
return (outputs.loss, outputs)
|
||||
else:
|
||||
return outputs.loss
|
||||
|
||||
|
||||
# def evaluate(
|
||||
# self,
|
||||
# eval_dataset: Optional[Dict[str, Dataset]] = None,
|
||||
# ignore_keys: Optional[List[str]] = None,
|
||||
# metric_key_prefix: str = "eval",
|
||||
# max_length: Optional[int] = None,
|
||||
# num_beams: Optional[int] = None,
|
||||
# ) -> Dict[str, float]:
|
||||
# # TODO: this also needs to be set per dataset
|
||||
# self._max_length = max_length
|
||||
# self._num_beams = num_beams
|
||||
# return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model, #nn.Module,
|
||||
inputs, #Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only, #: bool,
|
||||
ignore_keys, #: Optional[List[str]] = None,
|
||||
): #-> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (:obj:`bool`):
|
||||
Whether or not to return the loss only.
|
||||
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
||||
labels (each being optional).
|
||||
"""
|
||||
if not self.args.predict_with_generate or prediction_loss_only:
|
||||
return super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
|
||||
|
||||
has_labels = "labels" in inputs
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
intlabel = inputs.pop('label')
|
||||
gen_kwargs = {
|
||||
"max_length": 10, # self._max_length if s is not None else self.model.config.max_length,
|
||||
"num_beams": 1 #self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
||||
}
|
||||
generated_tokens = self.model.generate(
|
||||
inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
**gen_kwargs,
|
||||
)
|
||||
# in case the batch is shorter than max length, the output should be padded
|
||||
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
outputs = model(**inputs)
|
||||
if has_labels:
|
||||
if self.label_smoother is not None:
|
||||
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
||||
else:
|
||||
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
||||
else:
|
||||
loss = None
|
||||
|
||||
if self.args.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
labels = inputs["labels"]
|
||||
if labels.shape[-1] < gen_kwargs["max_length"]:
|
||||
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
||||
|
||||
# from IPython import embed; embed(header="In seqseqtrainer")
|
||||
return (loss, generated_tokens, labels)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
|
||||
# aa = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
||||
# # from IPython import embed; embed()
|
||||
# return aa
|
||||
from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
|
||||
class DataCollatorWithExtraFields(DataCollatorMixin):
|
||||
return_tensors: str = "pt"
|
||||
def torch_call(self, features):
|
||||
print(len(features))
|
||||
extra_fields = [f.pop('extra_fields') for f in features]
|
||||
batch = torch_default_data_collator(features)
|
||||
batch['extra_fields'] =extra_fields
|
||||
print(batch['input_ids'].size())
|
||||
print(batch['labels'].size())
|
||||
return batch
|
||||
# from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
|
||||
# class DataCollatorWithExtraFields(DataCollatorMixin):
|
||||
# return_tensors: str = "pt"
|
||||
# def torch_call(self, features):
|
||||
# # print(len(features))
|
||||
# # extra_fields = [f.pop('extra_fields') for f in features]
|
||||
# batch = torch_default_data_collator(features)
|
||||
# batch['extra_fields'] =extra_fields
|
||||
# # print(batch['input_ids'].size())
|
||||
# # print(batch['labels'].size())
|
||||
# return batch
|
||||
|
||||
|
||||
# from transformers.data.data_collator import DefaultDataCollator
|
||||
|
@ -455,15 +664,29 @@ def main():
|
|||
|
||||
training_args.remove_unused_columns = False
|
||||
|
||||
if os.path.basename(model_args.model_name_or_path).startswith("roberta") or \
|
||||
os.path.basename(model_args.model_name_or_path).startswith("bert"):
|
||||
trainer = MLMTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_splits['eval'][data_args.task_name] if training_args.do_eval else None,
|
||||
eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
|
||||
compute_metrics=functools.partial(compute_metrics, dataset_name=data_args.task_name),
|
||||
# tokenizer=tokenizer,
|
||||
data_collator=DataCollatorWithExtraFields(),
|
||||
tokenizer=tokenizer,
|
||||
# data_collator=DataCollatorWithExtraFields(),
|
||||
verbalizer=verbalizer,
|
||||
)
|
||||
elif os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
||||
trainer = MySeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
|
||||
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer, dataset_name=data_args.task_name, eval_metric=eval_metric),
|
||||
tokenizer=tokenizer,
|
||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Saves training config.
|
||||
|
@ -522,24 +745,23 @@ def main():
|
|||
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
for dataset_name, eval_dataset in eval_splits['eval'].items():
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset,
|
||||
|
||||
metrics = trainer.evaluate(eval_dataset=eval_splits['eval'],
|
||||
)
|
||||
trainer.log_metrics(f"{dataset_name}_eval", metrics)
|
||||
trainer.save_metrics(f"{dataset_name}_eval", metrics)
|
||||
all_results['evaluate'][dataset_name] = metrics
|
||||
trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
|
||||
trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
|
||||
all_results['evaluate'][data_args.task_name] = metrics
|
||||
|
||||
# Test
|
||||
all_results['test'] = {}
|
||||
if training_args.do_test:
|
||||
logger.info("*** Test ***")
|
||||
for dataset_name, test_dataset in eval_splits['test'].items():
|
||||
metrics = trainer.evaluate(eval_dataset=test_dataset,
|
||||
metrics = trainer.evaluate(eval_dataset=eval_splits['test'],
|
||||
metric_key_prefix="test"
|
||||
)
|
||||
trainer.log_metrics(f"{dataset_name}_test", metrics)
|
||||
trainer.save_metrics(f"{dataset_name}_test", metrics)
|
||||
all_results['test'][dataset_name] = metrics
|
||||
trainer.log_metrics(f"{data_args.task_name}_test", metrics)
|
||||
trainer.save_metrics(f"{data_args.task_name}_test", metrics)
|
||||
all_results['test'][data_args.task_name] = metrics
|
||||
|
||||
# repo_name = create_hub_repo_name(root="DeltaHub",
|
||||
# dataset=data_args.task_name,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
cd configs/
|
||||
python config_gen.py --job $3
|
||||
|
||||
python configs/config_gen.py --job $3
|
||||
echo "Regenerate config"
|
||||
cd ../
|
||||
|
||||
files=(cola mnli mrpc qnli qqp rte sst2 stsb superglue-boolq superglue-cb superglue-copa superglue-multirc superglue-record superglue-wic superglue-wsc.fixed)
|
||||
for ((i=$1; i<=$2; i++))
|
||||
do
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
__version__ = "0.0.1"
|
||||
__version__ = "0.0.4"
|
||||
|
||||
class GlobalSetting:
|
||||
def __init__(self):
|
||||
|
|
|
@ -728,6 +728,9 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
elif _delta_info['method'] == "insert_sequential":
|
||||
self.insert_sequential_module(module=submodule,
|
||||
_delta_info=_delta_info)
|
||||
elif _delta_info['method'] == "insert_parallel":
|
||||
self.insert_parallel_module(module=submodule,
|
||||
_delta_info=_delta_info)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -765,7 +768,13 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
submodule.forward = submodule.forward.__wrapped__
|
||||
delattr(submodule, _delta_info["delta_name"])
|
||||
else:
|
||||
raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It'ss not a wrapped function.".format(name))
|
||||
raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
|
||||
elif _delta_info['method'] == "insert_parallel":
|
||||
if hasattr(submodule.forward, "__wrapped__"):
|
||||
submodule.forward = submodule.forward.__wrapped__
|
||||
delattr(submodule, _delta_info["delta_name"])
|
||||
else:
|
||||
raise AttributeError("submodule {}'s forward has no attribute __wrapped__. It's not a wrapped function.".format(name))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -776,4 +785,3 @@ class DeltaBase(nn.Module, SaveLoadMixin):
|
|||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@ class LowRankLinear(nn.Module):
|
|||
self.r = r
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_dropout = lora_dropout
|
||||
self.lin = nn.Linear(in_features, out_features) #
|
||||
if lora_dropout > 0.:
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
|
@ -37,7 +36,6 @@ class LowRankLinear(nn.Module):
|
|||
self.lora_A = nn.Parameter(weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(weight.new_zeros((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
self.lin.reset_parameters() #
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from examples_prompt.metrics.metrics import exact_match
|
||||
from opendelta.utils.signature import get_arg_names, get_arg_names_inside_func
|
||||
from opendelta.utils.name_based_addressing import *
|
||||
from opendelta.utils.cuda import get_device
|
||||
|
@ -47,6 +48,7 @@ class SoftPromptLayer(nn.Module):
|
|||
soft_token_num: int = 100,
|
||||
raw_embedding: Optional[torch.Tensor] = None,
|
||||
init_range: Optional[float] = 0.5,
|
||||
other_expand_ids: Optional[Dict] = {"attention_mask":1, "token_type_ids":0},
|
||||
token_init = False,
|
||||
pad_id = 0,
|
||||
device: Optional[str]=None,
|
||||
|
@ -59,10 +61,13 @@ class SoftPromptLayer(nn.Module):
|
|||
self.pad_id = pad_id
|
||||
self.token_init = token_init
|
||||
self.device = device
|
||||
self.other_expand_ids = other_expand_ids
|
||||
|
||||
assert self.num_tokens>0
|
||||
self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
|
||||
|
||||
self.all_pseudo_tokens = {}
|
||||
|
||||
def pre_forward(self, *args, **kwargs):
|
||||
# if attention_mask is passed as PLM's input, modify it here
|
||||
if 'encoder_outputs' in kwargs and kwargs['encoder_outputs'] is not None:
|
||||
|
@ -100,8 +105,19 @@ class SoftPromptLayer(nn.Module):
|
|||
inputs_embeds = torch.cat([soft_embeds, inputs_embeds], 1)
|
||||
kwargs['inputs_embeds'] = inputs_embeds
|
||||
|
||||
am = kwargs['attention_mask']
|
||||
am.data = torch.cat([torch.ones((*am.shape[:-1], inputs_embeds.shape[-2]-am.shape[-1]), dtype = am.dtype,device=am.device), am], dim=-1)
|
||||
for expand_key in self.other_expand_ids:
|
||||
if expand_key in kwargs:
|
||||
real_tokens = kwargs[expand_key]
|
||||
if expand_key in self.all_pseudo_tokens:
|
||||
pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
|
||||
else:
|
||||
pseudo_tokens_value = self.other_expand_ids[expand_key]
|
||||
pseudo_tokens = torch.ones(
|
||||
(*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
|
||||
dtype = real_tokens.dtype,
|
||||
device=real_tokens.device) * pseudo_tokens_value
|
||||
self.all_pseudo_tokens[expand_key] = pseudo_tokens
|
||||
real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
|
||||
|
||||
return args, kwargs
|
||||
|
||||
|
@ -136,6 +152,10 @@ class SoftPromptModel(DeltaBase):
|
|||
soft_token_num (:obj:`int`, *optional*): num of new tokens to add in the front of the input.
|
||||
init_range (:obj:`float`, *optional*): If initialize new tokens randomly, the random range of uniform distribution.
|
||||
token_init (:obj:`bool`, *optional*, default to :obj:`True`): Whether to initialize the new tokens with tokens of the plm
|
||||
other_expand_ids (:obj:`dict`, *optional*, default to `{"attention_mask":1, "token_type_ids":0}`) The name of
|
||||
other tokens and its default value that expand along with the input sequence. For example, when
|
||||
you prepend 100 tokens to the input_ids, the attention_mask should be extended, and the token_type_ids should
|
||||
be extended as well.
|
||||
modified_modules (:obj:`List[str]`): For prefix tuning, the it must refer to an attention layer (Currently, only
|
||||
the implemented ones)
|
||||
unfrozen_modules (:obj:`List[str]`, *optional*, default to :obj:`None`): The modules that should be unfrozen
|
||||
|
@ -151,6 +171,7 @@ class SoftPromptModel(DeltaBase):
|
|||
soft_token_num=100,
|
||||
init_range = 0.5,
|
||||
token_init=True,
|
||||
other_expand_ids={"attention_mask":1, "token_type_ids":0},
|
||||
modified_modules: Optional[List[str]] = None,
|
||||
exclude_modules: Optional[List[str]] = None,
|
||||
unfrozen_modules: Optional[List[str]] = None,
|
||||
|
@ -202,6 +223,7 @@ class SoftPromptModel(DeltaBase):
|
|||
soft_prompt_layer = SoftPromptLayer(
|
||||
soft_token_num = self.soft_token_num,
|
||||
raw_embedding = self.raw_embedding,
|
||||
other_expand_ids = self.other_expand_ids,
|
||||
token_init = self.token_init,
|
||||
init_range = self.init_range,
|
||||
device = module_device,
|
||||
|
|
2
setup.py
2
setup.py
|
@ -17,7 +17,7 @@ print(requires)
|
|||
with open('README.md', 'r') as f:
|
||||
setuptools.setup(
|
||||
name = 'opendelta',
|
||||
version = '0.0.3',
|
||||
version = '0.0.4',
|
||||
description = "An open source framework for delta learning (parameter efficient learning).",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
|
Loading…
Reference in New Issue