# LiResolver/RE/data/dialogue.py

import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers.file_utils import PaddingStrategy
from transformers.models.bert import BertTokenizer, BertTokenizerFast
from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase

from .base_data_module import BaseDataModule
from .processor import get_dataset, processors


@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        model (:class:`~transformers.PreTrainedModel`):
            The model that is being trained. If set and it has the `prepare_decoder_input_ids_from_labels` method,
            it is used to prepare the `decoder_input_ids`. This is useful when using `label_smoothing` to avoid
            calculating the loss twice.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
            >= 7.5 (Volta).
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
    """

    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors
        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
                if isinstance(feature["labels"], list):
                    feature["labels"] = (
                        feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
                    )
                elif padding_side == "right":
                    feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
                else:
                    feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # prepare decoder_input_ids (only when labels are present and the model supports it)
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids

        return features
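
# Illustrative usage sketch (not part of the original module): the collator is
# meant to be passed as `collate_fn` to a DataLoader, as BartREDataset does
# below. The feature dicts and token ids here are made-up placeholders.
#
#   tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#   collator = DataCollatorForSeq2Seq(tokenizer, padding="longest")
#   batch = collator([
#       {"input_ids": [0, 8, 2], "labels": [0, 9, 2]},
#       {"input_ids": [0, 8, 7, 2], "labels": [0, 9, 5, 7, 2]},
#   ])
#   # "labels" is padded with -100 so the padded positions are ignored by the
#   # loss, while "input_ids" / "attention_mask" are padded by `tokenizer.pad`.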


class DIALOGUE(BaseDataModule):
    def __init__(self, args) -> None:
        super().__init__(args)
        self.processor = processors[self.args.task_name](self.args.data_dir, self.args.use_prompt)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path)

        self.num_labels = len(self.processor.get_labels())
        class_list = [f"[class{i}]" for i in range(1, self.num_labels + 1)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})
        unused_list = [f"[unused{i}]" for i in range(1, 50)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': unused_list})
        speaker_list = [f"[speaker{i}]" for i in range(1, 50)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': speaker_list})
        so_list = ["[sub]", "[obj]"]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})

    def setup(self, stage=None):
        self.data_train = get_dataset("train", self.args, self.tokenizer, self.processor)
        self.data_val = get_dataset("dev", self.args, self.tokenizer, self.processor)
        self.data_test = get_dataset("test", self.args, self.tokenizer, self.processor)

    def prepare_data(self):
        pass

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--task_name", type=str, default="normal", help="[normal, reloss, ptune]")
        parser.add_argument("--model_name_or_path", type=str, default="/home/xx/bert-base-uncased", help="Path or name of the pretrained transformer to load.")
        parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum number of tokens per input sequence.")
        parser.add_argument("--ptune_k", type=int, default=7, help="number of unused tokens in prompt")
        return parser


class WIKI80(BaseDataModule):
    def __init__(self, args, model=None) -> None:
        super().__init__(args)
        self.processor = processors[self.args.task_name](self.args.data_dir, self.args.use_prompt, self.args.ossl2_label_type)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path)

        use_gpt = "gpt" in args.model_name_or_path

        rel2id = self.processor.get_labels()
        self.num_labels = len(rel2id)

        entity_list = ["[object_start]", "[object_end]", "[subject_start]", "[subject_end]"]
        class_list = [f"[class{i}]" for i in range(1, self.num_labels + 1)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})
        if use_gpt:
            self.tokenizer.add_special_tokens({'cls_token': "[CLS]"})
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        so_list = ["[sub]", "[obj]"]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})

        prompt_tokens = [f"[T{i}]" for i in range(1, 6)]
        self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})

    def setup(self, stage=None):
        self.data_train = get_dataset("train", self.args, self.tokenizer, self.processor)
        self.data_val = get_dataset("dev", self.args, self.tokenizer, self.processor)
        self.data_test = get_dataset("test", self.args, self.tokenizer, self.processor)

    def setup_1(self, stage=None):
        self.data_train = get_dataset("train", self.args, self.tokenizer, self.processor)
        self.data_val = get_dataset("dev", self.args, self.tokenizer, self.processor)

    def setup_2(self, stage=None):
        self.data_test = get_dataset("test", self.args, self.tokenizer, self.processor)

    def prepare_data(self):
        pass

    def get_tokenizer(self):
        return self.tokenizer

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--task_name", type=str, default="wiki80", help="[normal, reloss, ptune]")  # normal
        parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path or name of the pretrained transformer to load.")  # /home/xx/bert-base-uncased
        parser.add_argument("--max_seq_length", type=int, default=256, help="Maximum number of tokens per input sequence.")  # 512
        parser.add_argument("--ptune_k", type=int, default=7, help="number of unused tokens in prompt")
        return parser
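
# Illustrative sketch (not part of the original file) of how a data module such
# as WIKI80 is typically built from the CLI flags registered above. The extra
# flags it reads (--data_dir, --use_prompt, --ossl2_label_type, ...) are
# assumed to be registered elsewhere, e.g. by BaseDataModule.add_to_argparse.
#
#   import argparse
#   parser = WIKI80.add_to_argparse(argparse.ArgumentParser())
#   args = parser.parse_args()
#   dm = WIKI80(args)
#   dm.setup()                      # builds the train/dev/test datasets
#   tokenizer = dm.get_tokenizer()  # tokenizer now contains the added special tokens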


class SST2(BaseDataModule):
    def __init__(self, args) -> None:
        super().__init__(args)
        self.processor = processors[self.args.task_name](self.args.data_dir, self.args.use_prompt)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path)

        labels = self.processor.get_labels()
        self.num_labels = len(labels)
        class_list = [f"[class{i}]" for i in range(1, self.num_labels + 1)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})
        if args.CT_CL:
            prompt_tokens = [f"[T{i}]" for i in range(1, 6)]
            self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})

        self.data_train = get_dataset("train", self.args, self.tokenizer, self.processor)
        # total optimisation steps = batches per epoch // gradient accumulation * epochs
        self.num_training_steps = len(self.data_train) // self.batch_size // self.args.accumulate_grad_batches * self.args.max_epochs
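
    # Worked example for the step computation above (illustrative numbers, not
    # from the original project): with 10_000 training examples, batch_size=32,
    # accumulate_grad_batches=2 and max_epochs=5 this gives
    #   10_000 // 32 // 2 * 5 = 780 optimisation steps.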

    def setup(self, stage=None):
        self.data_val = get_dataset("dev", self.args, self.tokenizer, self.processor)
        self.data_test = get_dataset("test", self.args, self.tokenizer, self.processor)

    def prepare_data(self):
        pass

    def get_tokenizer(self):
        return self.tokenizer

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--task_name", type=str, default="normal", help="[normal, reloss, ptune]")
        parser.add_argument("--model_name_or_path", type=str, default="/home/xx/bert-base-uncased", help="Path or name of the pretrained transformer to load.")
        parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum number of tokens per input sequence.")
        parser.add_argument("--ptune_k", type=int, default=7, help="number of unused tokens in prompt")
        return parser


class BartREDataset(BaseDataModule):
    def __init__(self, args, model=None) -> None:
        super().__init__(args)
        self.processor = processors[self.args.task_name](self.args.data_dir, self.args.use_prompt)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model_name_or_path)

        rel2id = self.processor.get_labels()
        self.num_labels = len(rel2id)

        entity_list = ["[object_start]", "[object_end]", "[subject_start]", "[subject_end]"]
        class_list = [f"[class{i}]" for i in range(1, self.num_labels + 1)]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})
        so_list = ["[sub]", "[obj]"]
        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})

        prompt_tokens = [f"[T{i}]" for i in range(1, 6)]
        self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})
        if "t5" in self.args.model_name_or_path:
            self.tokenizer.add_special_tokens({'mask_token': "<mask>"})

        self.collate_fn = DataCollatorForSeq2Seq(
            self.tokenizer,
            model=model,
            label_pad_token_id=self.tokenizer.pad_token_id,
            pad_to_multiple_of=8 if self.args.fp16 else None,
            padding="longest",
            max_length=self.args.max_seq_length,
        )

    def setup(self, stage=None):
        self.data_train = get_dataset("train", self.args, self.tokenizer, self.processor)
        self.data_val = get_dataset("dev", self.args, self.tokenizer, self.processor)
        self.data_test = get_dataset("test", self.args, self.tokenizer, self.processor)

    def prepare_data(self):
        pass

    def get_tokenizer(self):
        return self.tokenizer

    @staticmethod
    def add_to_argparse(parser):
        BaseDataModule.add_to_argparse(parser)
        parser.add_argument("--task_name", type=str, default="normal", help="[normal, reloss, ptune]")
        parser.add_argument("--model_name_or_path", type=str, default="/home/xx/bert-base-uncased", help="Path or name of the pretrained transformer to load.")
        parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum number of tokens per input sequence.")
        parser.add_argument("--ptune_k", type=int, default=7, help="number of unused tokens in prompt")
        return parser

    def train_dataloader(self):
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=False, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=False, collate_fn=self.collate_fn)

    def test_dataloader(self):
        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=False, collate_fn=self.collate_fn)
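
# Illustrative sketch (not part of the original file): these data modules follow
# the pytorch_lightning LightningDataModule pattern (assuming BaseDataModule
# subclasses it), so BartREDataset's custom dataloaders with the seq2seq
# collator are picked up automatically by a Trainer:
#
#   import pytorch_lightning as pl
#   dm = BartREDataset(args, model=model)          # `args` and `model` as above
#   trainer = pl.Trainer(max_epochs=args.max_epochs)
#   trainer.fit(model=lit_model, datamodule=dm)    # `lit_model` is hypothetical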