update examples
This commit is contained in: parent 0de3cbd31d, commit 42747f2b81
@@ -25,7 +25,7 @@ log.txt
 !examples/jupyter_notebook_examples/*.py

-!**/examples/*/configs/config_gen.py
+!**/examples/*/configs/*.py

 **/outputs_search/**/*.bin
 **/outputs_search/**/*.pt
@@ -1 +0,0 @@
-Subproject commit 058e5f25c898a1f956e3f17a0db6d62f08173e7f

@@ -1 +0,0 @@
-Subproject commit 3a5083d61e73bae607574a3047deafaa76b97646
@@ -1,50 +0,0 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Use OpenDelta in vision transformer ViT

This example builds on the [huggingface image classification examples]() by adding several lines to the original scripts; a condensed sketch of those additions follows.
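The OpenDelta-specific lines are essentially the ones below (a minimal sketch taken from run_image_classification.py later in this commit; in the real script `delta_args` comes from the parsed JSON config rather than being written by hand):

```python
from opendelta import AutoDeltaConfig, AutoDeltaModel

# Build the delta (e.g. LoRA) config from the parsed arguments and attach it to the backbone,
# then freeze everything except the delta modules and the modules listed in "unfrozen_modules".
delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
delta_model.freeze_module(set_state_dict=True)
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
```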
## Usage
### 1. Install the necessary packages
```shell
pip install Pillow
pip install torchvision
pip install transformers==4.16.2
pip install datasets==1.18.0
```

### 2. Run
```bash
python run_image_classification.py configs/lora_beans.json
```

Do not forget to reinstall datasets 1.17.0 (`pip install datasets==1.17.0`) afterwards for the other examples. :)


## Possible Errors
1. Dataset connection error

   Solution 1: open a Python console and run the failing command again; this may not help.

   Solution 2: download the dataset yourself on an internet-connected machine, save it to disk, transfer it to your server, and finally load it with `load_from_disk`, as sketched below.
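   A minimal sketch of that workaround, assuming the `beans` dataset used by this example and a `beans_saved/` directory of your choice:

   ```python
   from datasets import load_dataset, load_from_disk

   # On a machine with internet access:
   ds = load_dataset("beans")
   ds.save_to_disk("beans_saved")

   # After copying beans_saved/ to the offline server:
   ds = load_from_disk("beans_saved")
   ```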

## Link to the original training scripts
You may find solutions to other questions about the scripts that are unrelated to OpenDelta at
https://github.com/huggingface/transformers/tree/master/examples/pytorch/image-classification
@@ -1,30 +0,0 @@
{
    "report_to": "none",
    "dataset_name": "beans",
    "output_dir": "./beans_outputs/",
    "do_train": true,
    "do_eval": true,
    "num_train_epochs": 5,
    "remove_unused_columns": false,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "logging_strategy": "steps",
    "logging_steps": 10,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "load_best_model_at_end": true,
    "save_total_limit": 3,
    "seed": 1337,
    "delta_type": "lora",
    "modified_modules": [
        "attention.query",
        "attention.value"
    ],
    "unfrozen_modules": [
        "classifier",
        "deltas"
    ],
    "overwrite_output_dir": true,
    "learning_rate": 5e-4
}
@@ -1,89 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric."""

from sklearn.metrics import accuracy_score

import datasets


_DESCRIPTION = """
Accuracy is the proportion of correct predictions among the total number of cases processed. It can be computed with:
Accuracy = (TP + TN) / (TP + TN + FP + FN)
TP: True positive
TN: True negative
FP: False positive
FN: False negative
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: Predicted labels, as returned by a model.
    references: Ground truth labels.
    normalize: If False, return the number of correctly classified samples.
        Otherwise, return the fraction of correctly classified samples.
    sample_weight: Sample weights.
Returns:
    accuracy: Accuracy score.
Examples:

    >>> accuracy_metric = datasets.load_metric("accuracy")
    >>> results = accuracy_metric.compute(references=[0, 1], predictions=[0, 1])
    >>> print(results)
    {'accuracy': 1.0}
"""

_CITATION = """\
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
         and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
         and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
         Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("int32"),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _compute(self, predictions, references, normalize=True, sample_weight=None):
        return {
            "accuracy": float(
                accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
            )
        }
@@ -1,3 +0,0 @@
# torch>=1.5.0
torchvision>=0.6.0
datasets>=1.8.0
@@ -1,392 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

import transformers
from transformers import (
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version


""" Fine-tuning a 🤗 Transformers model for image classification"""

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.16.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def pil_loader(path: str):
    with open(path, "rb") as f:
        im = Image.open(f)
        return im.convert("RGB")


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using ``HfArgumentParser`` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    dataset_name: Optional[str] = field(
        default="nateraw/image-folder", metadata={"help": "Name of a dataset from the datasets package"}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."})
    validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."})
    train_val_split: Optional[float] = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        data_files = dict()
        if self.train_dir is not None:
            data_files["train"] = self.train_dir
        if self.validation_dir is not None:
            data_files["val"] = self.validation_dir
        self.data_files = data_files if data_files else None


class RemainArgHfArgumentParser(HfArgumentParser):
    def parse_json_file(self, json_file: str, return_remaining_args=True):
        """
        Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
        dataclass types.
        """
        import argparse
        import json
        from pathlib import Path
        import dataclasses

        data = json.loads(Path(json_file).read_text())
        outputs = []
        for dtype in self.dataclass_types:
            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: data.pop(k) for k in list(data.keys()) if k in keys}
            obj = dtype(**inputs)
            outputs.append(obj)

        remain_args = argparse.ArgumentParser()
        remain_args.__dict__.update(data)
        if return_remaining_args:
            return (*outputs, remain_args)
        else:
            return (*outputs,)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        default="google/vit-base-patch16-224-in21k",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args, delta_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, delta_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Initialize our dataset and prepare it for the 'image-classification' task.
    ds = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        data_files=data_args.data_files,
        cache_dir=model_args.cache_dir,
        task="image-classification",
    )
    # If you encounter an error here, try to download the dataset yourself and load it from disk,
    # like the following two lines:
    # from datasets import load_from_disk
    # ds = load_from_disk(f"../../../../huggingface_datasets/saved_to_disk/{data_args.dataset_name}")

    # If we don't have a validation split, split off a percentage of train as validation.
    data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = ds["train"].train_test_split(data_args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    labels = ds["train"].features["labels"].names
    label2id, id2label = dict(), dict()
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    # Load the accuracy metric from the datasets package
    # metric = datasets.load_metric("accuracy")
    metric = datasets.load_metric("metric.py")

    # Define our compute_metrics function. It takes an ``EvalPrediction`` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p):
        """Computes accuracy on a batch of predictions"""
        return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        num_labels=len(labels),
        label2id=label2id,
        id2label=id2label,
        finetuning_task="image-classification",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForImageClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if delta_args.delta_type.lower() != "none":
        from opendelta import AutoDeltaConfig, AutoDeltaModel
        delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
        delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
        delta_model.freeze_module(set_state_dict=True)
        delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)

    # Define torchvision transforms to be applied to each image.
    normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
    _train_transforms = Compose(
        [
            RandomResizedCrop(feature_extractor.size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
    _val_transforms = Compose(
        [
            Resize(feature_extractor.size),
            CenterCrop(feature_extractor.size),
            ToTensor(),
            normalize,
        ]
    )

    def train_transforms(example_batch):
        """Apply _train_transforms across a batch."""
        example_batch["pixel_values"] = [
            _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
        ]
        return example_batch

    def val_transforms(example_batch):
        """Apply _val_transforms across a batch."""
        example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
        return example_batch

    if training_args.do_train:
        if "train" not in ds:
            raise ValueError("--do_train requires a train dataset")
        if data_args.max_train_samples is not None:
            ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
        # Set the training transforms
        ds["train"].set_transform(train_transforms)

    if training_args.do_eval:
        if "validation" not in ds:
            raise ValueError("--do_eval requires a validation dataset")
        if data_args.max_eval_samples is not None:
            ds["validation"] = (
                ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
            )
        # Set the validation transforms
        ds["validation"].set_transform(val_transforms)

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"] if training_args.do_train else None,
        eval_dataset=ds["validation"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
        data_collator=collate_fn,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
@@ -0,0 +1,145 @@
from openpromptu.data_utils import InputExample
import torch
from transformers.data.data_collator import torch_default_data_collator
from transformers.data.data_collator import DataCollatorMixin as HfDataCollatorMixin
import numpy as np
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForImageClassification,
)
from transformers import ViTFeatureExtractor

from transformers import Trainer as HfTrainer
import torch.nn as nn


def process_example(raw_example, **kwargs):
    tokenizer = kwargs['tokenizer']
    inputs = tokenizer(raw_example['image'], return_tensors='pt')
    inputs['labels'] = raw_example['labels']
    return inputs


def get_prompts(task, tokenizer, data_args, template_id="0", verbalizer_id="0"):
    # from openpromptu.prompts import ManualVerbalizer
    # from openpromptu.prompts import ManualTemplate
    # from openpromptu import TokenizerWrapper
    # template = ManualTemplate(text = task.templates_text[template_id])
    # verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
    # tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
    return None, None, None


def preprocess_function(raw_example, **kwargs):
    # from IPython import embed; embed(header="Therefa")
    tokenizer = kwargs['tokenizer']
    model_inputs = tokenizer(raw_example['image'], return_tensors='pt')
    model_inputs['pixel_values'] = model_inputs['pixel_values'].squeeze()
    model_inputs['labels'] = raw_example['labels']
    return model_inputs


def compute_metrics(eval_preds, dataset_name, eval_metric):
    # from IPython import embed; embed(header="In compute metrics")

    preds, labels = eval_preds.predictions, eval_preds.label_ids

    preds = np.argmax(preds, axis=-1)

    result = {}
    average_metrics = []
    for metric in eval_metric:
        metric_item = metric(preds, labels)
        metric_value = list(metric_item.values())
        result.update(metric_item)
        average_metrics.extend(metric_value)
    print("average:", average_metrics)
    average_metric = sum(average_metrics) / len(average_metrics)
    result.update({"average_metrics": average_metric})
    return result


def mask_token_func(tokenizer, ith_mask=0):
    return tokenizer.mask_token


def get_remove_columns(dataset_features):
    # dataset_features.pop("label")
    print("remove_columns: {}".format(dataset_features))
    return dataset_features


class DataCollator(HfDataCollatorMixin):
    def __init__(self, *args, **kwargs):
        self.return_tensors = 'pt'

    def torch_call(self, features):
        # from IPython import embed; embed(header="in data collator")
        a = torch_default_data_collator(features=features)
        # from IPython import embed; embed(header="in data collator")
        return a


def get_backbone(model_args, **kwargs):
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.dropout_rate = 0.0
    tokenizer = AutoFeatureExtractor.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = AutoModelForImageClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.num_labels = model_args.num_classes
    old_classifier = model.classifier
    model.classifier = nn.Linear(old_classifier.in_features, config.num_labels)

    return config, tokenizer, model


class Trainer(HfTrainer):
    def __init__(self, verbalizer=None, eval_task=None, **kwargs):
        super().__init__(**kwargs)
        self.verbalizer = verbalizer
        self.eval_task = eval_task
        self.compute_metrics = self._compute_metrics
        self.loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.get("logits")

        loss = self.loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

    def _compute_metrics(self, eval_preds):
        # from IPython import embed; embed(header="In compute metrics")

        preds, labels = eval_preds.predictions, eval_preds.label_ids

        preds = np.argmax(preds, axis=-1)

        result = {}
        average_metrics = []
        for metric in self.eval_task.metric:
            metric_item = metric(preds, labels)
            metric_value = list(metric_item.values())
            result.update(metric_item)
            average_metrics.extend(metric_value)
        print("average:", average_metrics)
        average_metric = sum(average_metrics) / len(average_metrics)
        result.update({"average_metrics": average_metric})
        # from IPython import embed; embed(header="In compute metrics")
        return result
@@ -0,0 +1,141 @@
from openpromptu.data_utils import InputExample
import torch
from transformers.data.data_collator import torch_default_data_collator
from transformers.data.data_collator import DataCollatorMixin as HfDataCollatorMixin
import numpy as np
from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
)

from transformers import Trainer as HfTrainer


def preprocess_function(raw_example, **kwargs):
    tokenizer = kwargs['tokenizer']
    data_args = kwargs['data_args']
    template = kwargs['template']
    verbalizer = kwargs['verbalizer']
    tokenizer_wrapper = kwargs['tokenizer_wrapper']

    example = InputExample(**raw_example)
    example, other = template.wrap_one_example(example)
    input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
    model_inputs = tokenizer(input_sentence, max_length=data_args.max_source_length,
                             padding="max_length", truncation=True)
    return model_inputs


def compute_metrics(eval_preds, dataset_name, eval_metric):
    # from IPython import embed; embed(header="In compute metrics")

    preds, labels = eval_preds.predictions, eval_preds.label_ids

    preds = np.argmax(preds, axis=-1)

    result = {}
    average_metrics = []
    for metric in eval_metric:
        metric_item = metric(preds, labels)
        metric_value = list(metric_item.values())
        result.update(metric_item)
        average_metrics.extend(metric_value)
    print("average:", average_metrics)
    average_metric = sum(average_metrics) / len(average_metrics)
    result.update({"average_metrics": average_metric})
    return result


def mask_token_func(tokenizer, ith_mask=0):
    return tokenizer.mask_token


def get_remove_columns(dataset_features):
    dataset_features.pop("label")
    return dataset_features


def get_prompts(task, tokenizer, data_args, template_id="0", verbalizer_id="0"):
    from openpromptu.prompts import ManualVerbalizer
    from openpromptu.prompts import ManualTemplate
    from openpromptu import TokenizerWrapper
    template = ManualTemplate(text=task.templates_text[template_id])
    verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes=task.labels_list, label_words=task.verbalizers[verbalizer_id])
    tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
    return template, verbalizer, tokenizer_wrapper


class DataCollator(HfDataCollatorMixin):
    def __init__(self, *args, **kwargs):
        self.return_tensors = 'pt'

    def torch_call(self, features):
        return torch_default_data_collator(features=features)


def get_backbone(model_args, **kwargs):
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.dropout_rate = 0.0
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForMaskedLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model.resize_token_embeddings(len(tokenizer))
    return config, tokenizer, model


class Trainer(HfTrainer):
    def __init__(self, verbalizer=None, eval_task=None, **kwargs):
        super().__init__(**kwargs)
        self.verbalizer = verbalizer
        self.eval_task = eval_task
        self.compute_metrics = self._compute_metrics

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop('labels')
        outputs = model(**inputs)
        logits = outputs.get("logits")
        input_ids = inputs['input_ids']
        verbalizer = self.verbalizer.cuda()
        logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
        label_logits = verbalizer.process_logits(logits_at_mask)
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(label_logits, labels)
        outputs.logits = label_logits
        return (loss, outputs) if return_outputs else loss

    def _compute_metrics(self, eval_preds):
        # from IPython import embed; embed(header="In compute metrics")

        preds, labels = eval_preds.predictions, eval_preds.label_ids

        preds = np.argmax(preds, axis=-1)

        result = {}
        average_metrics = []
        for metric in self.eval_task.metric:
            metric_item = metric(preds, labels)
            metric_value = list(metric_item.values())
            result.update(metric_item)
            average_metrics.extend(metric_value)
        print("average:", average_metrics)
        average_metric = sum(average_metrics) / len(average_metrics)
        result.update({"average_metrics": average_metric})
        return result
@@ -0,0 +1,178 @@
from openpromptu.data_utils import InputExample
from transformers import Seq2SeqTrainer as HfSeq2SeqTrainer
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from transformers.data.data_collator import DataCollatorForSeq2Seq as DataCollator
import torch


def mask_token_func(tokenizer, ith_mask):
    return tokenizer.additional_special_tokens[ith_mask]


def get_remove_columns(dataset_features):
    return dataset_features


def preprocess_function(raw_example, **kwargs):
    # max_target_length += 1
    tokenizer = kwargs['tokenizer']
    data_args = kwargs['data_args']
    template = kwargs['template']
    verbalizer = kwargs['verbalizer']
    tokenizer_wrapper = kwargs['tokenizer_wrapper']
    split = kwargs['split']
    example = InputExample(**raw_example)

    try:
        example = verbalizer.wrap_one_example(example)
        example, other = template.wrap_one_example(example)
        input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
        model_inputs = tokenizer(input_sentence, max_length=256,
                                 padding="max_length", truncation=True)
    except:
        from IPython import embed; embed(header="Therer")

    with tokenizer.as_target_tokenizer():
        label = tokenizer(other['tgt_text']).input_ids

    model_inputs["labels"] = label
    return model_inputs


def get_backbone(model_args, **kwargs):
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    config.dropout_rate = 0.0
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    return config, tokenizer, model


def get_prompts(task, tokenizer, data_args, template_id="0", verbalizer_id="0"):
    from openpromptu.prompts import GenerationVerbalizer
    from openpromptu.prompts import ManualTemplate
    from openpromptu import TokenizerWrapper
    template = ManualTemplate(text=task.templates_text[template_id])
    verbalizer = GenerationVerbalizer(tokenizer=tokenizer, classes=task.labels_list, label_words=task.verbalizers[verbalizer_id])
    tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
    return template, verbalizer, tokenizer_wrapper


class Trainer(HfSeq2SeqTrainer):
    def __init__(self, verbalizer=None, eval_task=None, **kwargs):
        super().__init__(**kwargs)
        self.eval_task = eval_task
        self.compute_metrics = self._compute_metrics

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        if return_outputs:
            return (outputs.loss, outputs)
        else:
            return outputs.loss

    def prediction_step(
        self,
        model,  # nn.Module,
        inputs,  # Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only,  #: bool,
        ignore_keys,  #: Optional[List[str]] = None,
    ):  # -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)
        gen_kwargs = {
            "max_length": 10,  # self._max_length if s is not None else self.model.config.max_length,
            "num_beams": 1,  # self._num_beams if self._num_beams is not None else self.model.config.num_beams,
        }
        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        with torch.no_grad():

            outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        # from IPython import embed; embed(header="In seqseqtrainer")
        return (loss, generated_tokens, labels)

    def _compute_metrics(self, eval_preds):
        # from IPython import embed; embed(header="In compute metrics")
        preds, labels = eval_preds
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        # post_processor = .get(data_args.dataset_name[0], tokenizer,
        #                       data_args.ignore_pad_token_for_loss)
        # decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
        result = {}
        for metric in self.eval_task.metric:
            result.update(metric(decoded_preds, decoded_labels))

        average_metric = sum(result.values()) / len(result)
        result.update({"average_metrics": average_metric})
        return result
@@ -1,28 +0,0 @@
# import numpy as np
# from dataclasses import dataclass
# from transformers import DataCollatorForSeq2Seq


# @dataclass
# class TaskDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
#     def check_uniqueness(self, samples):
#         assert len(np.unique(samples)) == 1

#     def __call__(self, features):
#         # tasks = [d.pop('task') for d in features]
#         # self.check_uniqueness(tasks)
#         output = super().__call__(features)
#         # output["task"] = tasks[0]
#         return output

# # class CustomDataCollator(DefaultDataCollator):
# #     def check_uniqueness(self, samples):
# #         assert len(np.unique(samples)) == 1

# #     def __call__(self, features):
# #         mask_positions = [d.pop('mask_positions') for d in features]
# #         # self.check_uniqueness(tasks)
# #         output = super().__call__(features)

# #         # output["task"] = tasks[0]
# #         return output
@@ -1,67 +0,0 @@
import abc
from collections import OrderedDict
import numpy as np

"""Defines functions to process the outputs to make them ready for the evaluation."""

def string_to_float(string, default=-1., **unused_kwargs):
    """Converts string to float, using default when conversion not possible."""
    try:
        return float(string)
    except ValueError:
        return default


class PostProcessor(abc.ABC):
    """Postprocess the predictions and labels to make them suitable for
    evaluation."""
    def __init__(self, tokenizer, ignore_pad_token_for_loss):
        self.tokenizer = tokenizer
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss

    def process(self, preds, labels, data_info=None):
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        if self.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Some simple post-processing
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [label.strip() for label in decoded_labels]
        return decoded_preds, decoded_labels


class MultiRC(PostProcessor):
    def process(self, preds, labels, data_info):
        preds, labels = super().process(preds, labels, data_info)
        preds = [{"group": info["group"], "value": pred} \
                 for info, pred in zip(data_info, preds)]
        labels = [{"group": info["group"], "value": label} \
                  for info, label in zip(data_info, labels)]
        return preds, labels

class Record(PostProcessor):
    def process(self, preds, labels, data_info):
        preds, labels = super().process(preds, labels, data_info)
        labels = [info["answers"] for info in data_info]
        return preds, labels


POSTPROCESSOR_MAPPING = OrderedDict(
    [
        ('superglue-record', Record),
        ('superglue-multirc', MultiRC)
    ]
)

class AutoPostProcessor:
    @classmethod
    def get(self, task, tokenizer, ignore_pad_token_for_loss):
        if task in POSTPROCESSOR_MAPPING:
            return POSTPROCESSOR_MAPPING[task](tokenizer, ignore_pad_token_for_loss)
        return PostProcessor(tokenizer, ignore_pad_token_for_loss)
@@ -0,0 +1,96 @@
import abc
from typing import Callable, List, Mapping, Dict
import datasets
import logging
import numpy as np
import torch
logger = logging.getLogger(__name__)


class AbstractTask(abc.ABC):
    name = NotImplemented
    config = NotImplemented
    prefix = NotImplemented
    metric = NotImplemented
    metric_names = NotImplemented
    split_map = None
    labels_list = None
    split_to_data_split: Mapping[str, str] = \
        {"train": "train", "validation": "validation", "test": "test"}
    split_valid_to_make_test = True
    split_train_to_make_test = False
    keep_fields_after_preprocess = ["label"]  # The fields that should be kept even after preprocessing

    def __init__(self, config, data_args, seed=42, default_max_length=1):
        self.config = config
        self.seed = seed
        self.data_args = data_args

        self.default_max_length = default_max_length

    def check_n_obs(self, n_obs, total_size):
        if n_obs is not None and n_obs > total_size:
            n_obs = total_size
            logger.warning("n_obs is set to %s", n_obs)
        return n_obs

    def shuffled_indices(self, dataset):
        num_samples = len(dataset)
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return torch.randperm(num_samples, generator=generator).tolist()

    def subsample(self, dataset, n_obs=None, indices=None):
        """
        Given a dataset returns the subsampled dataset.
        :param n_obs: the number of samples of the subsampled dataset.
        :param indices: indices to select the samples from; if not given, indices are computed
            by shuffling the given dataset.
        :return: subsampled dataset.
        """
        num_samples = len(dataset)
        n_obs = self.check_n_obs(n_obs, num_samples)
        if indices is None:
            indices = self.shuffled_indices(dataset)
        indices = indices[:n_obs]
        return dataset.select(indices)

    def load_dataset(self, split: int):
        return datasets.load_dataset(self.name, self.config, split=split, script_version="master")

    def get_split_indices(self, split, dataset, validation_size):
        indices = self.shuffled_indices(dataset)
        if split == "validation":
            return indices[:validation_size]
        else:
            return indices[validation_size:]

    def preprocessor(self, example):
        return example

    def get(self, split, n_obs=None, split_validation_test=False):
        # For small datasets (n_samples < 10K) without test set, we divide the validation set in
        # half, using one half as the test set and one half as the validation set.
        if split in ["eval", "dev", "valid"]:
            split = "validation"
        if split_validation_test and self.split_valid_to_make_test \
                and split != "train":
            mapped_split = self.split_to_data_split["validation"]
            dataset = self.load_dataset(split=mapped_split)
            indices = self.get_split_indices(split, dataset, validation_size=len(dataset)//2)
            dataset = self.subsample(dataset, n_obs, indices)
        # For larger datasets (n_samples > 10K), we split 1K off the training set as
        # validation and use the rest as training set, keeping the original validation
        # set as the test set.
        elif split_validation_test and self.split_train_to_make_test \
                and split != "test":
            dataset = self.load_dataset(split="train")
            indices = self.get_split_indices(split, dataset, validation_size=1000)
            dataset = self.subsample(dataset, n_obs, indices)
        else:
            mapped_split = self.split_to_data_split[split]
            dataset = self.load_dataset(split=mapped_split)
            # shuffles the data and samples it.
            if n_obs is not None:
                dataset = self.subsample(dataset, n_obs)
        return dataset.map(self.preprocessor)
@@ -4,7 +4,7 @@ import abc
 import functools
 from selectors import EpollSelector
 from typing import Callable, List, Mapping
-from examples_prompt.trainers.trainer_utils import pad_punctuation
+from .utils import pad_punctuation
 from examples_prompt.metrics import metrics
 from .utils import round_stsb_target
 import datasets
@@ -30,281 +30,8 @@ from collections import defaultdict
 from openprompt.utils import round_list
 import warnings

+from .processor import AbstractTask
-# class MLMTokenizerWrapper:
-#     def __init__(self, max_seq_length, tokenizer, truncate_method, mask_token_func=lambda i: "<mask>"):
-#         self.max_seq_length=max_seq_length
-#         self.tokenizer=tokenizer
-#         self.num_special_tokens_to_add = len(tokenizer("")['input_ids'])
-#         # from IPython import embed; embed(header="Truega")
-#         self.truncate_method=truncate_method
-#         self.total_passed_sentences = 0
-#         self.num_truncated_sentences = 0
-#         self.mask_token_func = mask_token_func
-
-#         if truncate_method=='tail':
-#             self.truncate_fct = self.truncate_from_tail
-#         elif truncate_method=='head':
-#             self.truncate_fct = self.truncate_from_head
-#         elif truncate_method == 'balanced':
-#             self.truncate_fct = self.balanced_truncate
-#         else:
-#             raise NotImplementedError
-
-
-#     def merge_wrapped_example(self, wrapped_example,):
-#         ''' # TODO doens't consider the situation that input has two parts
-#         '''
-
-#         wrapped_example
-
-#         # for some dataset like SuperGLUE.COPA, the answer requires prediction an span of
-#         # the input. Or in generation tasks, we need to generate a piece of target_text.
-#         # In these case, it tokenized to the encoded_tgt_text for furture use.
-
-
-#         encoder_inputs = defaultdict(list)
-#         # from IPython import embed; embed(header="Line 67")
-
-#         mask_count = 0
-#         for piece in wrapped_example:
-#             if piece['text'] == "<mask>":
-#                 encode_text = self.tokenizer.encode(self.mask_token_func(mask_count), add_special_tokens=False, return_special_tokens_mask=True )
-#                 mask_count += 1
-#             else:
-#                 encode_text = self.tokenizer.encode(piece['text'], add_special_tokens=False, return_special_tokens_mask=True )
-#             encoder_inputs['input_ids'].append(encode_text)
-#             encoder_inputs['shortenable_ids'].append([piece['shortenable_ids']] * len(encode_text))
-
-#         encoder_inputs = self.truncate(encoder_inputs=encoder_inputs)
-#         encoder_inputs.pop("shortenable_ids")
-#         encoder_inputs = self.concate_parts(input_dict=encoder_inputs)
-#         decoded_inputs = self.tokenizer.decode(encoder_inputs['input_ids'], clean_up_tokenization_spaces=False)
-
-#         return decoded_inputs
-
-
-#     @staticmethod
-#     def balanced_truncate(input_dict: Dict,
-#                           num_tokens_to_truncate: int=0) -> Dict:
-#         '''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
-#         '''
-#         shortenable_lens = [len(parts) if parts[0]==1 else 0
-#                             for parts in input_dict['shortenable_ids']]
-#         total_shortenable_len = sum(shortenable_lens)
-#         num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
-#                                             for part_len in shortenable_lens]
-#         round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)
-
-#         truncated_example = defaultdict(list)
-#         for key in input_dict:
-#             parts = input_dict[key]
-#             for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
-#                 truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
-#         return truncated_example
-
-#     @staticmethod
-#     def truncate_from_tail(input_dict: Dict,
-#                            num_tokens_to_truncate: int=0) -> Dict:
-#         r"""truncate the inputs from the rear
-#         """
-#         truncated_example = defaultdict(list)
-#         shortenable_ids = input_dict['shortenable_ids']
-
-#         for key in input_dict:
-#             parts = input_dict[key]
-#             to_trunc = num_tokens_to_truncate
-#             for i, part in enumerate(parts[::-1]):
-#                 if len(part) == 0:  # to prevent some part are empty after tokenization
-#                     continue
-#                 if shortenable_ids[-1-i][0]==0:  # ==0 means the part is not shortenable
-#                     continue
-#                 parts[-1-i] = part[:-to_trunc] if to_trunc<len(part) else []
-#                 to_trunc -= len(part)
-#                 if to_trunc <= 0:
-#                     break
-#             truncated_example[key] = parts
-#         return truncated_example
-
-#     @staticmethod
-#     def truncate_from_head(input_dict: Dict,
-#                            num_tokens_to_truncate: int=0) -> Dict:
-#         r"""truncate the inputs from the head
-#         """
-#         truncated_example = defaultdict(list)
-#         shortenable_ids = input_dict['shortenable_ids']
-#         for key in input_dict:
-#             parts = input_dict[key]
|
||||||
# to_trunc = num_tokens_to_truncate
|
|
||||||
# for i, part in enumerate(parts):
|
|
||||||
# if shortenable_ids[i][0]==0: # ==0 means the part is not shortenable
|
|
||||||
# continue
|
|
||||||
# parts[i] = part[:-to_trunc] if to_trunc<len(part) else []
|
|
||||||
# to_trunc -= len(part)
|
|
||||||
# if to_trunc <= 0:
|
|
||||||
# break
|
|
||||||
# truncated_example[key] = parts
|
|
||||||
# return truncated_example
|
|
||||||
|
|
||||||
# @staticmethod
|
|
||||||
# def concate_parts(input_dict: Dict) -> Dict:
|
|
||||||
# for key in input_dict:
|
|
||||||
# input_dict[key] = list(itertools.chain(*input_dict[key]))
|
|
||||||
# return input_dict
|
|
||||||
|
|
||||||
|
|
||||||
# def truncate(self, encoder_inputs):
|
|
||||||
# total_tokens = sum([len(part) for part in encoder_inputs['input_ids']])
|
|
||||||
# num_specials = self.num_special_tokens_to_add
|
|
||||||
# # print("num_specials", num_specials)
|
|
||||||
# num_tokens_to_truncate = total_tokens - self.max_seq_length + num_specials
|
|
||||||
# self.total_passed_sentences+=1
|
|
||||||
# if num_tokens_to_truncate>0:
|
|
||||||
# self.num_truncated_sentences += 1
|
|
||||||
# if num_tokens_to_truncate > sum([len(x) for x in encoder_inputs['shortenable_ids']]):
|
|
||||||
# raise RuntimeError("num_tokens_to_truncate larger than number of shortenable tokens.")
|
|
||||||
# encoder_inputs = self.truncate_fct(input_dict=encoder_inputs,
|
|
||||||
# num_tokens_to_truncate=num_tokens_to_truncate)
|
|
||||||
# return encoder_inputs
|
|
||||||
|
|
||||||
# def tokenizer_preprocessor(self, example):
|
|
||||||
# # source, target = example
|
|
||||||
# # from IPython import embed; embed(header="Trehre2")
|
|
||||||
# label = example['label']
|
|
||||||
# guid = example['idx']
|
|
||||||
# meta = dict(example)
|
|
||||||
# meta.pop("label")
|
|
||||||
# meta.pop("idx")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# # from IPython import embed; embed(header="Trehre2")
|
|
||||||
|
|
||||||
# e = InputExample(**{"meta": meta, 'label': label, 'guid': guid})
|
|
||||||
|
|
||||||
# if self.predict_with_generate:
|
|
||||||
# e = self.verbalizer.wrap_one_example(e)
|
|
||||||
# example_wrapped = self.template.wrap_one_example(e)
|
|
||||||
# encoded_sentence = self.tokenizer_wrapper.merge_wrapped_example(example_wrapped)
|
|
||||||
# print(encoded_sentence)
|
|
||||||
# if self.predict_with_generate:
|
|
||||||
# # return {"source": encoded_sentence, 'target': ', 'extra_fields':[]}
|
|
||||||
# return {"source": encoded_sentence, "label": label, 'target': '', 'extra_fields':{'dataset_name':self.name}}
|
|
||||||
# else:
|
|
||||||
# return {"source": encoded_sentence, "label": label, 'target': e.target_text, 'extra_fields':{'dataset_name':self.name}}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractTask(abc.ABC):
    name = NotImplemented
    config = NotImplemented
    prefix = NotImplemented
    metric = NotImplemented
    metric_names = NotImplemented
    split_map = None
    labels_list = None
    split_to_data_split: Mapping[str, str] = \
        {"train": "train", "validation": "validation", "test": "test"}
    small_datasets_without_all_splits = ["cola", "wnli", "rte", "superglue-cb", "superglue-copa", "superglue-multirc",
                                         "superglue-wic", "superglue-wsc.fixed", "superglue-rte", "mrpc", "stsb",
                                         "superglue-boolq", "qqp", "qnli", "superglue-record", "sst2"]
    large_data_without_all_splits = []  # ["qqp", "qnli", "superglue-record", "sst2"]

    def __init__(self, config, data_args, seed=42, default_max_length=1):
        self.config = config
        self.seed = seed
        self.data_args = data_args
        # self.tokenizer = tokenizer
        # self.predict_with_generate = predict_with_generate
        self.default_max_length = default_max_length

        # generation_paradigm = getattr(config, "generation_paradigm", True)
        # self.prompt = PromptCollections[self.name](tid, vid, generation_paradigm)

    # def get_max_target_length(self, default_max_length):
    #     if self.predict_with_generate:
    #         return -1
    #     else:
    #         return default_max_length

    def check_n_obs(self, n_obs, total_size):
        if n_obs is not None and n_obs > total_size:
            n_obs = total_size
            logger.warning("n_obs is set to %s", n_obs)
        return n_obs

    def shuffled_indices(self, dataset):
        num_samples = len(dataset)
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return torch.randperm(num_samples, generator=generator).tolist()

    def subsample(self, dataset, n_obs=None, indices=None):
        """
        Given a dataset returns the subsampled dataset.
        :param n_obs: the number of samples of the subsampled dataset.
        :param indices: indices to select the samples from, if not given, indices are computed
        from by shuffling the given dataset.
        :return: subsampled dataset.
        """
        num_samples = len(dataset)
        n_obs = self.check_n_obs(n_obs, num_samples)
        if indices is None:
            indices = self.shuffled_indices(dataset)
        indices = indices[:n_obs]
        return dataset.select(indices)

    def load_dataset(self, split: int):
        return datasets.load_dataset(self.name, self.config, split=split, script_version="master")

    def get_split_indices(self, split, dataset, validation_size):
        indices = self.shuffled_indices(dataset)
        if split == "validation":
            return indices[:validation_size]
        else:
            return indices[validation_size:]

    def preprocessor(self, example):
        return example

    def get(self, split, n_obs=None, split_validation_test=False):
        # For small datasets (n_samples < 10K) without test set, we divide validation set to
        # half, use one half as test set and one half as validation set.
        if split in ["eval", "dev", "valid"]:
            split = "validation"
        if split_validation_test and self.name in self.small_datasets_without_all_splits \
                and split != "train":
            mapped_split = self.split_to_data_split["validation"]
            dataset = self.load_dataset(split=mapped_split)
            indices = self.get_split_indices(split, dataset, validation_size=len(dataset)//2)
            dataset = self.subsample(dataset, n_obs, indices)
        # For larger datasets (n_samples > 10K), we divide training set into 1K as
        # validation and the rest as training set, keeping the original validation
        # set as the test set.
        elif split_validation_test and self.name in self.large_data_without_all_splits \
                and split != "test":
            dataset = self.load_dataset(split="train")
            indices = self.get_split_indices(split, dataset, validation_size=1000)
            dataset = self.subsample(dataset, n_obs, indices)
        else:
            mapped_split = self.split_to_data_split[split]
            dataset = self.load_dataset(split=mapped_split)
            # shuffles the data and samples it.
            if n_obs is not None:
                dataset = self.subsample(dataset, n_obs)
        return dataset.map(self.preprocessor)


class Squad(AbstractTask):
    name = "squad"
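For reference, the seeded subsampling that check_n_obs, shuffled_indices, and subsample implement can be exercised in isolation. The snippet below is a minimal sketch using an in-memory Dataset so nothing is downloaded; it is not part of the diff.

# Sketch: deterministic subsampling in the style of AbstractTask.subsample.
import torch
from datasets import Dataset

dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(100)],
                             "label": [i % 2 for i in range(100)]})

def subsample(dataset, n_obs, seed=42):
    generator = torch.Generator()
    generator.manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=generator).tolist()
    n_obs = min(n_obs, len(dataset))          # mirrors check_n_obs
    return dataset.select(indices[:n_obs])

small = subsample(dataset, n_obs=8)
print(len(small), small["label"])             # the same 8 rows on every run with seed 42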
@@ -735,118 +462,95 @@ class SuperGLUEWIC(AbstractTask):
        return datasets.load_dataset('super_glue', 'wic', split=split, script_version="master")
|
||||||
|
|
||||||
|
|
||||||
|
# class SuperGLUERecord(AbstractTask):
|
||||||
# class SuperGLUEWSCFixed(AbstractTask):
|
# """Convert ReCoRD examples to text2text examples.
|
||||||
# # source: https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py
|
# ReCoRD contains a passage, query containing a '@placeholder' string, and a set
|
||||||
# """Convert WSC examples to text2text format.
|
# of entities that are the possible values of the placeholder. Each train and
|
||||||
# WSC includes a sentence along with 2 'spans': the first denoting a noun and
|
# validation example will have a list of answers, any of which would be
|
||||||
# the other a pronoun. The 'label' specifies whether or not the pronoun is
|
# considered correct.
|
||||||
# referencing the noun. This preprocessor puts ' * ' around the noun and ' # '
|
# For example, a typical example from ReCoRD might look like
|
||||||
# around the pronoun.
|
# {
|
||||||
# For example, a typical example from WSC might look like
|
# 'passsage': 'This is the passage.',
|
||||||
# {
|
# 'query': 'A @placeholder is a bird.',
|
||||||
# 'text': 'This is a test sentence .',
|
# 'entities': ['penguin', 'potato', 'pigeon'],
|
||||||
# 'span1_text': 'test',
|
# 'answers': ['penguin', 'pigeon'],
|
||||||
# 'span1_index': 3,
|
# }
|
||||||
# 'span2_text': 'This',
|
# which this preprocessor would turn into the following two examples:
|
||||||
# 'span2_index': 0,
|
# {
|
||||||
# 'label': 0
|
# 'inputs': 'record query: A @placeholder is a bird. entities: penguin, '
|
||||||
# }
|
# 'potato, pigeon passage: This is the passage.',
|
||||||
# This example would be transformed to
|
# 'targets': 'penguin',
|
||||||
# {
|
# }
|
||||||
# 'inputs': 'wsc text: # This # is a * test * sentence .',
|
# and
|
||||||
# 'targets': 'False'
|
# {
|
||||||
# }
|
# 'inputs': 'record query: A @placeholder is a bird. entities: penguin, '
|
||||||
|
# 'potato, pigeon passage: This is the passage.',
|
||||||
|
# 'targets': 'pigeon',
|
||||||
|
# }
|
||||||
# """
|
# """
|
||||||
# name = "superglue-wsc.fixed"
|
# name = "superglue-record"
|
||||||
# labels_list = ['0', '1']
|
|
||||||
# split_to_data_split = {"train": "train",
|
# split_to_data_split = {"train": "train",
|
||||||
# "validation": "validation",
|
# "validation": "validation",
|
||||||
# "test": "validation"}
|
# "test": "validation"}
|
||||||
# metric = [metrics.accuracy]
|
# metric = [metrics.squad]
|
||||||
# metric_names = ["accuracy"]
|
# metric_names = ["squad"]
|
||||||
|
|
||||||
# def load_dataset(self, split):
|
# def load_dataset(self, split):
|
||||||
# return datasets.load_dataset('super_glue', 'wsc.fixed', split=split, script_version="master")
|
# return datasets.load_dataset('super_glue', 'record', split=split, script_version="master")
|
||||||
|
|
||||||
# def _mark_span(self, text, span_str, span_idx, mark):
|
# def preprocessor(self, batch, add_prefix=True):
|
||||||
# pattern_tmpl = r'^((?:\S+\s){N})(W)'
|
# new_batch = collections.defaultdict(list)
|
||||||
# pattern = re.sub('N', str(span_idx), pattern_tmpl)
|
# keys = batch.keys()
|
||||||
# pattern = re.sub('W', span_str, pattern)
|
# for values in zip(*batch.values()):
|
||||||
# return re.sub(pattern, r'\1{0} \2 {0}'.format(mark), text)
|
# ex = {k: v for k, v in zip(keys, values)}
|
||||||
|
# # updates the passage.
|
||||||
|
# passage = ex['passage']
|
||||||
|
# passage = re.sub(r'(\.|\?|\!|\"|\')\n@highlight\n', r'\1 ', passage)
|
||||||
|
# passage = re.sub(r'\n@highlight\n', '. ', passage)
|
||||||
|
# inputs = f"record query: {ex['query']} entities: {', '.join(ex['entities'])} passage: {passage}"
|
||||||
|
# if add_prefix:
|
||||||
|
# inputs = self.name + " " + inputs
|
||||||
|
# # duplicates the samples based on number of answers.
|
||||||
|
# num_answers = len(ex["answers"])
|
||||||
|
# num_duplicates = np.maximum(1, num_answers)
|
||||||
|
# new_batch["source"].extend([inputs] * num_duplicates)
|
||||||
|
# new_batch["target"].extend(ex["answers"] if num_answers > 0 else ["<unk>"])
|
||||||
|
# new_batch["task"].extend([self.name] * num_duplicates)
|
||||||
|
# new_batch["extra_fields"].extend([{"answers": ex["answers"]}]*num_duplicates)
|
||||||
|
# return new_batch
|
||||||
|
|
||||||
# def preprocessor(self, example, add_prefix=True):
|
# def map_dataset(self, dataset, add_prefix=True):
|
||||||
# # converts text as done in T5.
|
# return dataset.map(functools.partial(self.preprocessor, add_prefix=add_prefix),
|
||||||
# text = example['text']
|
# batched=True, remove_columns=dataset.column_names)
|
||||||
# text = self._mark_span(text, example['span1_text'], example['span1_index'], '*')
|
|
||||||
# # Compensate for 2 added "words" added in previous step.
|
|
||||||
# span2_index = example['span2_index'] + 2 * int(example['span1_index'] < example['span2_index'])
|
|
||||||
# text = self._mark_span(text, example['span2_text'], span2_index, '#')
|
|
||||||
# src_texts = ["text:", text]
|
|
||||||
# tgt_texts = [str(example["label"])]
|
|
||||||
# return self.fseq2seq_format(src_texts, tgt_texts, add_prefix)
|
|
||||||
|
|
||||||
|
-class SuperGLUERecord(AbstractTask):
-    """Convert ReCoRD examples to text2text examples.
-    ReCoRD contains a passage, query containing a '@placeholder' string, and a set
-    of entities that are the possible values of the placeholder. Each train and
-    validation example will have a list of answers, any of which would be
-    considered correct.
-    For example, a typical example from ReCoRD might look like
-    {
-      'passsage': 'This is the passage.',
-      'query': 'A @placeholder is a bird.',
-      'entities': ['penguin', 'potato', 'pigeon'],
-      'answers': ['penguin', 'pigeon'],
-    }
-    which this preprocessor would turn into the following two examples:
-    {
-      'inputs': 'record query: A @placeholder is a bird. entities: penguin, '
-                'potato, pigeon passage: This is the passage.',
-      'targets': 'penguin',
-    }
-    and
-    {
-      'inputs': 'record query: A @placeholder is a bird. entities: penguin, '
-                'potato, pigeon passage: This is the passage.',
-      'targets': 'pigeon',
-    }
-    """
-    name = "superglue-record"
-    split_to_data_split = {"train": "train",
-                           "validation": "validation",
-                           "test": "validation"}
-    metric = [metrics.squad]
-    metric_names = ["squad"]
-
-    def load_dataset(self, split):
-        return datasets.load_dataset('super_glue', 'record', split=split, script_version="master")
-
-    def preprocessor(self, batch, add_prefix=True):
-        new_batch = collections.defaultdict(list)
-        keys = batch.keys()
-        for values in zip(*batch.values()):
-            ex = {k: v for k, v in zip(keys, values)}
-            # updates the passage.
-            passage = ex['passage']
-            passage = re.sub(r'(\.|\?|\!|\"|\')\n@highlight\n', r'\1 ', passage)
-            passage = re.sub(r'\n@highlight\n', '. ', passage)
-            inputs = f"record query: {ex['query']} entities: {', '.join(ex['entities'])} passage: {passage}"
-            if add_prefix:
-                inputs = self.name + " " + inputs
-            # duplicates the samples based on number of answers.
-            num_answers = len(ex["answers"])
-            num_duplicates = np.maximum(1, num_answers)
-            new_batch["source"].extend([inputs] * num_duplicates)
-            new_batch["target"].extend(ex["answers"] if num_answers > 0 else ["<unk>"])
-            new_batch["task"].extend([self.name] * num_duplicates)
-            new_batch["extra_fields"].extend([{"answers": ex["answers"]}]*num_duplicates)
-        return new_batch
-
-    def map_dataset(self, dataset, add_prefix=True):
-        return dataset.map(functools.partial(self.preprocessor, add_prefix=add_prefix),
-                           batched=True, remove_columns=dataset.column_names)
+class Beans(AbstractTask):
+    name = "beans"
+    labels_list = ['angular_leaf_spot', 'bean_rust', "healthy"]
+    split_to_data_split = {"train": "train",
+                           "validation": "validation",
+                           "test": "validation"}
+    metric = [metrics.accuracy]
+    metric_names = ["accuracy"]
+
+    verbalizers = {
+        "0": {
+            "0": "No",
+            "1": "Yes",
+        }
+    }
+
+    templates_text = {
+        "0": """{"meta":"sentence1"}"""
+    }
+
+    def load_dataset(self, split):
+        # from IPython import embed; embed(header="beans")
+        if self.data_args.datasets_load_from_disk:
+            return datasets.load_from_disk(f"{self.data_args.datasets_saved_path}/super_glue.wic")[split]
+        else:
+            return datasets.load_dataset('beans', split=split, script_version="master")
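The removed ReCoRD preprocessor's main trick is emitting one source string per gold answer, so every answer becomes a valid target. A condensed, standalone sketch of that behaviour on toy data (not the class itself):

import collections
import numpy as np

def duplicate_by_answers(batch):
    new_batch = collections.defaultdict(list)
    for query, entities, passage, answers in zip(batch["query"], batch["entities"],
                                                 batch["passage"], batch["answers"]):
        inputs = f"record query: {query} entities: {', '.join(entities)} passage: {passage}"
        num_duplicates = int(np.maximum(1, len(answers)))   # at least one copy, even with no answer
        new_batch["source"].extend([inputs] * num_duplicates)
        new_batch["target"].extend(answers if answers else ["<unk>"])
    return new_batch

batch = {"query": ["A @placeholder is a bird."],
         "entities": [["penguin", "potato", "pigeon"]],
         "passage": ["This is the passage."],
         "answers": [["penguin", "pigeon"]]}
out = duplicate_by_answers(batch)
print(out["source"][0])
print(out["target"])                          # ['penguin', 'pigeon']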
TASK_MAPPING = OrderedDict(
@@ -866,8 +570,8 @@ TASK_MAPPING = OrderedDict(
         ('superglue-copa', SuperGLUECOPA),
         ('superglue-multirc', SuperGLUEMultiRC),
         ('superglue-wic', SuperGLUEWIC),
-        # ('superglue-wsc.fixed', SuperGLUEWSCFixed),
-        ('superglue-record', SuperGLUERecord)
+        # ('superglue-record', SuperGLUERecord)
+        ('beans', Beans)
     ]
)
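Entries in this mapping are resolved by name elsewhere in the examples (see the AutoTask.get calls later in this diff). A rough, hypothetical sketch of that resolution pattern, with a dummy task standing in for the real classes:

from collections import OrderedDict

class DummyTask:
    name = "beans"
    def __init__(self, config, data_args, seed=42):
        self.config, self.data_args, self.seed = config, data_args, seed

TASK_REGISTRY = OrderedDict([("beans", DummyTask)])       # illustrative registry

class AutoTaskSketch:
    @classmethod
    def get(cls, task_name, config=None, data_args=None, seed=42):
        if task_name not in TASK_REGISTRY:
            raise KeyError(f"Unknown task {task_name!r}; known: {list(TASK_REGISTRY)}")
        return TASK_REGISTRY[task_name](config, data_args, seed=seed)

task = AutoTaskSketch.get("beans", config=None, data_args=None, seed=1337)
print(type(task).__name__, task.seed)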
@@ -1,4 +1,5 @@
 import numpy as np
+import re

 def round_stsb_target(label):
     """STSB maps two sentences to a floating point number between 1 and 5
@@ -15,3 +16,15 @@ def round_stsb_target(label):
     """
     return np.round((label * 5) / 5, decimals=1)
+
+
+def pad_punctuation(text):
+    """Re-implementation of _pad_punctuation in t5. This function adds spaces
+    around punctuation. While this pads punctuation as expected, it has the
+    unexpected effect of padding certain unicode characters with accents, with
+    spaces as well. For instance: "François" becomes "Fran ç ois"""
+    # Pad everything except for: underscores (_), whitespace (\s),
+    # numbers (\p{N}), letters (\p{L}) and accent characters (\p{M}).
+    text = re.sub(r'([^_\s\p{N}\p{L}\p{M}])', r' \1 ', text)
+    # Collapse consecutive whitespace into one space.
+    text = re.sub(r'\s+', ' ', text)
+    return text
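One caveat worth noting: the \p{N}/\p{L}/\p{M} escapes used above are not understood by Python's built-in re module (recent versions reject \p as a bad escape); the third-party regex package does support them. A minimal sketch under that assumption (pip install regex):

import regex

def pad_punctuation(text):
    # Same pattern as above, but run through `regex`, which implements \p{...} classes.
    text = regex.sub(r'([^_\s\p{N}\p{L}\p{M}])', r' \1 ', text)
    return regex.sub(r'\s+', ' ', text)

print(pad_punctuation("Hello, world! (François)"))   # pads ',', '!', '(' and ')' with spaces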
@@ -0,0 +1,44 @@
PATHBASE=/mnt/sfs_turbo/hsd/officialod/OpenDelta-1/examples/examples_prompt/
PYTHONPATH=/mnt/sfs_turbo/zhangshudan/anaconda3/envs/officialod/bin/python
PLMPATHBASE=/mnt/sfs_turbo/hsd/plm_cache/    # must be empty string or dir that ends with /
DATASETSPATHBASE=/mnt/sfs_turbo/hsd/huggingface_datasets/saved_to_disk/
RUNTIME=$(date +%m%d%H%M%S)
MODELNAME="roberta-base"
DATASET=$1
DELTATYPES=("none" "bitfit" "lora" "adapter")
CUDAIDS=("0 1" "2 3" "4 5" "6 7")
NUMTRIALS=50
CONTINUESTUDY=${2:-'0'}

echo $RUNTIME
echo $MODELNAME
echo $DATASET
echo $DELTATYPE
echo $CUDAIDS
echo $NUMTRIALS
echo $CONTINUESTUDY
cd $PATHBASE


for expid in 0 1 2 3
do
( $PYTHONPATH search_distributed.py \
    --model_name $MODELNAME \
    --dataset $DATASET \
    --delta_type ${DELTATYPES[$expid]} \
    --cuda_ids ${CUDAIDS[$expid]} \
    --num_trials $NUMTRIALS \
    --mode run \
    --repeat_time 1 \
    --main_file_name run_mlm.py \
    --pathbase $PATHBASE \
    --pythonpath $PYTHONPATH \
    --plm_path_base $PLMPATHBASE \
    --datasets_saved_path $DATASETSPATHBASE \
    --datasets_load_from_disk \
    --continue_study $CONTINUESTUDY >>/mnt/sfs_turbo/hsd/officialod/OpenDelta-1/examples/examples_prompt/out_sfs/$RUNTIME.txt 2>&1
) &
done
wait
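For readers who prefer Python over bash, the launch loop above roughly corresponds to the sketch below: one search_distributed.py process per delta type, launched in parallel and then joined. The flags mirror the script; the paths and argument values are site-specific assumptions, not a supported interface.

import subprocess

DELTA_TYPES = ["none", "bitfit", "lora", "adapter"]
CUDA_IDS = ["0 1", "2 3", "4 5", "6 7"]

procs = []
for delta_type, cuda_ids in zip(DELTA_TYPES, CUDA_IDS):
    cmd = ["python", "search_distributed.py",
           "--model_name", "roberta-base",
           "--dataset", "beans",
           "--delta_type", delta_type,
           "--cuda_ids", *cuda_ids.split(),   # two GPU ids per job, as in the shell array
           "--num_trials", "50",
           "--mode", "run",
           "--main_file_name", "run_mlm.py",
           "--datasets_load_from_disk"]
    procs.append(subprocess.Popen(cmd))       # launched in parallel, like `( ... ) &`
for p in procs:
    p.wait()                                  # like the trailing `wait`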
@@ -47,7 +47,7 @@ from examples_prompt.trainers.trainer_utils import save_training_config
 from dataclasses import dataclass, field

 from transformers.models.t5.modeling_t5 import T5Config, T5ForConditionalGeneration
-from examples_prompt.trainers.model_args import ModelArguments
+from examples_prompt.utils.args import ModelArguments
 from examples_prompt.trainers.trainer_args import TrainingArguments, DataTrainingArguments
 from transformers.trainer import Trainer
 from examples_prompt.metrics.metrics import transform_for_generation
@ -1,790 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
Fine-tuning the library models for sequence to sequence.
|
|
||||||
"""
|
|
||||||
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
# from opendelta.utils.delta_center import create_hub_repo_name
|
|
||||||
import torch
|
|
||||||
import os
|
|
||||||
os.environ['MKL_THREADING_LAYER'] = 'GNU'
|
|
||||||
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
from datasets import load_dataset, load_metric, concatenate_datasets
|
|
||||||
import transformers
|
|
||||||
from transformers import (
|
|
||||||
AutoConfig,
|
|
||||||
AutoModelForMaskedLM,
|
|
||||||
AutoModelForSeq2SeqLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
DataCollatorForSeq2Seq,
|
|
||||||
HfArgumentParser,
|
|
||||||
MBartTokenizer,
|
|
||||||
default_data_collator,
|
|
||||||
set_seed,
|
|
||||||
)
|
|
||||||
from transformers.trainer_utils import is_main_process, get_last_checkpoint
|
|
||||||
# from ..seq2seq.utils import get_adapter_config
|
|
||||||
from examples_prompt.data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
|
|
||||||
from transformers import Seq2SeqTrainer
|
|
||||||
# from training_args import AdapterTrainingArguments
|
|
||||||
from examples_prompt.trainers.trainer_utils import save_training_config
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from transformers.models.t5.modeling_t5 import T5Config, T5ForConditionalGeneration
|
|
||||||
from examples_prompt.trainers.model_args import ModelArguments
|
|
||||||
from examples_prompt.trainers.trainer_args import TrainingArguments, DataTrainingArguments
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
from examples_prompt.metrics.metrics import transform_for_generation
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
TASK_TO_METRICS = {"mrpc": ["accuracy", "f1"],
|
|
||||||
"cola": ['matthews_correlation'],
|
|
||||||
"stsb": ['pearson', 'spearmanr'],
|
|
||||||
'sst2': ['accuracy'],
|
|
||||||
"mnli": ["accuracy"],
|
|
||||||
"mnli_mismatched": ["accuracy"],
|
|
||||||
"mnli_matched": ["accuracy"],
|
|
||||||
"qnli": ["accuracy"],
|
|
||||||
"rte": ["accuracy"],
|
|
||||||
"wnli": ["accuracy"],
|
|
||||||
"qqp": ["accuracy", "f1"],
|
|
||||||
"superglue-boolq": ["accuracy"],
|
|
||||||
"superglue-rte": ["accuracy"],
|
|
||||||
"superglue-cb": ["f1_multiclass", "accuracy"],
|
|
||||||
"superglue-copa": ["accuracy"],
|
|
||||||
"superglue-multirc": ["f1", "em"],
|
|
||||||
"superglue-wic": ["accuracy"],
|
|
||||||
"superglue-wsc.fixed": ["accuracy"],
|
|
||||||
"superglue-record": ["f1", "em"]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class RemainArgHfArgumentParser(HfArgumentParser):
|
|
||||||
def parse_json_file(self, json_file: str, return_remaining_args=True ):
|
|
||||||
"""
|
|
||||||
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
|
|
||||||
dataclass types.
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
data = json.loads(Path(json_file).read_text())
|
|
||||||
outputs = []
|
|
||||||
for dtype in self.dataclass_types:
|
|
||||||
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
|
||||||
inputs = {k: data.pop(k) for k in list(data.keys()) if k in keys}
|
|
||||||
obj = dtype(**inputs)
|
|
||||||
outputs.append(obj)
|
|
||||||
|
|
||||||
remain_args = argparse.ArgumentParser()
|
|
||||||
remain_args.__dict__.update(data)
|
|
||||||
if return_remaining_args:
|
|
||||||
return (*outputs, remain_args)
|
|
||||||
else:
|
|
||||||
return (*outputs,)
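The core of parse_json_file above is routing one flat JSON config into several dataclasses and keeping whatever is left over. A standalone sketch of that routing, with illustrative field names:

import dataclasses, json
from dataclasses import dataclass

@dataclass
class ModelArgs:
    model_name_or_path: str = "roberta-base"

@dataclass
class DataArgs:
    task_name: str = "beans"

def split_config(data, dataclass_types):
    outputs = []
    for dtype in dataclass_types:
        keys = {f.name for f in dataclasses.fields(dtype) if f.init}
        inputs = {k: data.pop(k) for k in list(data.keys()) if k in keys}
        outputs.append(dtype(**inputs))
    return outputs, data                      # parsed dataclasses + remaining keys

raw = json.loads('{"model_name_or_path": "roberta-base", "task_name": "beans", "delta_type": "lora"}')
(model_args, data_args), remaining = split_config(raw, [ModelArgs, DataArgs])
print(model_args, data_args, remaining)       # remaining == {'delta_type': 'lora'}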
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# See all possible arguments in src/transformers/training_args.py
|
|
||||||
# or by passing the --help flag to this script.
|
|
||||||
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
|
||||||
parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
|
||||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
|
||||||
# If we pass only one argument to the script and it's the path to a json file,
|
|
||||||
# let's parse it to get our arguments.
|
|
||||||
model_args, data_args, training_args, delta_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
|
||||||
else:
|
|
||||||
model_args, data_args, training_args, delta_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
|
||||||
|
|
||||||
|
|
||||||
print(f"{training_args.output_dir}/results.json")
|
|
||||||
# exit()
|
|
||||||
# Detecting last checkpoint.
|
|
||||||
last_checkpoint = None
|
|
||||||
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
|
||||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
|
||||||
print("#### last_checkpoint ", last_checkpoint)
|
|
||||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
|
||||||
'''
|
|
||||||
raise ValueError(
|
|
||||||
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
|
||||||
"Use --overwrite_output_dir to overcome."
|
|
||||||
)
|
|
||||||
'''
|
|
||||||
pass
|
|
||||||
elif last_checkpoint is not None:
|
|
||||||
logger.info(
|
|
||||||
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
|
||||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(
|
|
||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
||||||
datefmt="%m/%d/%Y %H:%M:%S",
|
|
||||||
handlers=[logging.StreamHandler(sys.stdout)],
|
|
||||||
)
|
|
||||||
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
|
||||||
|
|
||||||
# Log on each process the small summary:
|
|
||||||
logger.warning(
|
|
||||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
|
||||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
|
||||||
)
|
|
||||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
|
||||||
if is_main_process(training_args.local_rank):
|
|
||||||
transformers.utils.logging.set_verbosity_info()
|
|
||||||
# logger.info("Training/evaluation parameters %s", training_args, model_args, data_args, delta_args)
|
|
||||||
logger.info("{}\n{}\n{}\n{}".format(training_args, model_args, data_args, delta_args))
|
|
||||||
|
|
||||||
|
|
||||||
# Set seed before initializing model.
|
|
||||||
set_seed(training_args.seed)
|
|
||||||
|
|
||||||
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
|
|
||||||
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
|
||||||
# (the dataset will be downloaded automatically from the datasets Hub).
|
|
||||||
#
|
|
||||||
# For CSV/JSON files in the summarization task, this script will use the first column for the full texts and the
|
|
||||||
# second column for the summaries (unless you specify column names for this with the `text_column` and
|
|
||||||
# `summary_column` arguments).
|
|
||||||
# For translation, only JSON files are supported, with one field named "translation" containing two keys for the
|
|
||||||
# source and target languages (unless you adapt what follows).
|
|
||||||
#
|
|
||||||
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
|
||||||
# download the dataset.
|
|
||||||
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
|
||||||
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
|
||||||
#
|
|
||||||
# Distributed training:
|
|
||||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
|
||||||
# download model & vocab.
|
|
||||||
|
|
||||||
config = AutoConfig.from_pretrained(
|
|
||||||
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
|
|
||||||
cache_dir=model_args.cache_dir,
|
|
||||||
revision=model_args.model_revision,
|
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
|
||||||
)
|
|
||||||
config.dropout_rate = 0.0
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
|
|
||||||
cache_dir=model_args.cache_dir,
|
|
||||||
use_fast=model_args.use_fast_tokenizer,
|
|
||||||
revision=model_args.model_revision,
|
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
||||||
model_args.model_name_or_path,
|
|
||||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
|
||||||
config=config,
|
|
||||||
cache_dir=model_args.cache_dir,
|
|
||||||
revision=model_args.model_revision,
|
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = AutoModelForMaskedLM.from_pretrained(
|
|
||||||
model_args.model_name_or_path,
|
|
||||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
|
||||||
config=config,
|
|
||||||
cache_dir=model_args.cache_dir,
|
|
||||||
revision=model_args.model_revision,
|
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
|
||||||
)
|
|
||||||
model.resize_token_embeddings(len(tokenizer))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if delta_args.delta_type.lower() != "none":
|
|
||||||
from opendelta import AutoDeltaConfig,AutoDeltaModel
|
|
||||||
delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
|
|
||||||
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
|
|
||||||
delta_model.freeze_module(set_state_dict = True)
|
|
||||||
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
|
|
||||||
|
|
||||||
|
|
||||||
# model parallelize
|
|
||||||
if hasattr(training_args, "model_parallel") and training_args.model_parallel:
|
|
||||||
logger.info('parallelize model!')
|
|
||||||
model.parallelize()
|
|
||||||
|
|
||||||
data_args.dataset_name = [data_args.task_name]
|
|
||||||
data_args.eval_dataset_name = [data_args.eval_dataset_name]
|
|
||||||
data_args.test_dataset_name = [data_args.test_dataset_name]
|
|
||||||
data_args.dataset_config_name = [data_args.dataset_config_name]
|
|
||||||
data_args.eval_dataset_config_name = [data_args.eval_dataset_config_name]
|
|
||||||
data_args.test_dataset_config_name = [data_args.test_dataset_config_name]
|
|
||||||
assert len(data_args.dataset_name) == len(data_args.dataset_config_name)
|
|
||||||
if data_args.eval_dataset_name is not None:
|
|
||||||
assert len(data_args.eval_dataset_name) == len(data_args.eval_dataset_config_name)
|
|
||||||
if data_args.test_dataset_name is not None:
|
|
||||||
assert len(data_args.test_dataset_name) == len(data_args.test_dataset_config_name)
|
|
||||||
|
|
||||||
# Temporarily set max_target_length for training.
|
|
||||||
#max_target_length = data_args.max_target_length
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
column_names = ['source', 'target', 'label', 'extra_fields']
|
|
||||||
performance_metrics = {}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompts(task, tokenizer, predict_with_generate, template_id="0", verbalizer_id="0"):
|
|
||||||
# tid = getattr(config, "template_id", "0")
|
|
||||||
# vid = getattr(config, "verbalizer_id", "0")
|
|
||||||
from openpromptu.prompts import GenerationVerbalizer, ManualVerbalizer
|
|
||||||
from openpromptu.prompts import ManualTemplate
|
|
||||||
template = ManualTemplate(text = task.templates_text[template_id])
|
|
||||||
if predict_with_generate:
|
|
||||||
verbalizer = GenerationVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
|
||||||
else:
|
|
||||||
verbalizer = ManualVerbalizer(tokenizer=tokenizer, classes = task.labels_list, label_words=task.verbalizers[verbalizer_id])
|
|
||||||
# max_target_length = self.get_max_target_length(self.default_max_length)
|
|
||||||
|
|
||||||
from openpromptu import TokenizerWrapper
|
|
||||||
tokenizer_wrapper = TokenizerWrapper(max_seq_length=data_args.max_source_length, tokenizer=tokenizer, truncate_method="balanced", mask_token_func=mask_token_func)
|
|
||||||
return template, verbalizer, tokenizer_wrapper
|
|
||||||
|
|
||||||
|
|
||||||
from openpromptu.data_utils import InputExample
|
|
||||||
|
|
||||||
max_target_length = 32
|
|
||||||
|
|
||||||
if os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
|
||||||
mask_token_func = lambda i: tokenizer.additional_special_tokens[i]
|
|
||||||
def preprocess_function(raw_example, **kwargs):
|
|
||||||
# max_target_length += 1
|
|
||||||
tokenizer = kwargs['tokenizer']
|
|
||||||
data_args = kwargs['data_args']
|
|
||||||
template = kwargs['template']
|
|
||||||
verbalizer = kwargs['verbalizer']
|
|
||||||
tokenizer_wrapper = kwargs['tokenizer_wrapper']
|
|
||||||
split = kwargs['split']
|
|
||||||
# extra_fileds = example['extra_fields']
|
|
||||||
|
|
||||||
example = InputExample(**raw_example)
|
|
||||||
|
|
||||||
# from collections import namedtuple
|
|
||||||
# example['tgt_text'] = ""
|
|
||||||
# example = namedtuple("ObjectName", example.keys())(*example.values())
|
|
||||||
try:
|
|
||||||
example = verbalizer.wrap_one_example(example)
|
|
||||||
example, other = template.wrap_one_example(example)
|
|
||||||
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
|
|
||||||
model_inputs = tokenizer(input_sentence, max_length=256,
|
|
||||||
padding="max_length", truncation=True)
|
|
||||||
except:
|
|
||||||
from IPython import embed; embed(header="Therer")
|
|
||||||
|
|
||||||
|
|
||||||
# if split == "train":
|
|
||||||
with tokenizer.as_target_tokenizer():
|
|
||||||
label = tokenizer(other['tgt_text']).input_ids
|
|
||||||
# label = [l if l != tokenizer.pad_token_id else -100 for l in label]
|
|
||||||
|
|
||||||
# from IPython import embed; embed(header="Therer")
|
|
||||||
model_inputs["labels"] = label
|
|
||||||
# else:
|
|
||||||
# # from IPython import embed; embed(header="Therer")
|
|
||||||
# model_inputs["tgt_text"] = other['tgt_text']
|
|
||||||
# model_inputs['labels'] = None # model_inputs["extra_fields"] = extra_fileds
|
|
||||||
# from IPython import embed; embed(header="Therer2")
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def compute_metrics(eval_preds, tokenizer, dataset_name, eval_metric):
|
|
||||||
# from IPython import embed; embed(header="In compute metrics")
|
|
||||||
preds, labels = eval_preds
|
|
||||||
decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
|
||||||
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
|
||||||
# post_processor = .get(data_args.dataset_name[0], tokenizer,
|
|
||||||
# data_args.ignore_pad_token_for_loss)
|
|
||||||
# decoded_preds, decoded_labels = post_processor.process(preds, labels, data_info)
|
|
||||||
result = {}
|
|
||||||
for metric in eval_metric:
|
|
||||||
result.update(metric(decoded_preds, decoded_labels))
|
|
||||||
|
|
||||||
average_metric = sum(result.values())/len(result)
|
|
||||||
result.update({"average_metrics":average_metric})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
|
|
||||||
or os.path.basename(model_args.model_name_or_path).startswith("bert"):
|
|
||||||
mask_token_func = lambda i: tokenizer.mask_token
|
|
||||||
def preprocess_function(raw_example, **kwargs):
|
|
||||||
# max_target_length += 1
|
|
||||||
|
|
||||||
# from IPython import embed; embed(header="Therer")
|
|
||||||
tokenizer = kwargs['tokenizer']
|
|
||||||
|
|
||||||
data_args = kwargs['data_args']
|
|
||||||
template = kwargs['template']
|
|
||||||
verbalizer = kwargs['verbalizer']
|
|
||||||
tokenizer_wrapper = kwargs['tokenizer_wrapper']
|
|
||||||
|
|
||||||
example = InputExample(**raw_example)
|
|
||||||
|
|
||||||
# from collections import namedtuple
|
|
||||||
# example['tgt_text'] = ""
|
|
||||||
# example = namedtuple("ObjectName", example.keys())(*example.values())
|
|
||||||
# try:
|
|
||||||
# example = verbalizer.wrap_one_example(example)
|
|
||||||
example, other = template.wrap_one_example(example)
|
|
||||||
input_sentence = tokenizer_wrapper.merge_wrapped_example(example)
|
|
||||||
model_inputs = tokenizer(input_sentence, max_length=256,
|
|
||||||
padding="max_length", truncation=True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# print("max_length", data_args.max_source_length)
|
|
||||||
# model_inputs = tokenizer(examples['source'], max_length=data_args.max_source_length,
|
|
||||||
# padding="max_length", truncation=True)
|
|
||||||
|
|
||||||
# mask_position = [(id, input_id.index(tokenizer.mask_token_id)) for id, input_id in enumerate(model_inputs.input_ids)]# [[-100 if i != tokenizer.mask_token_id else tokenizer.convert_tokens_to_ids(target) for i in input_id] for input_id, target in zip(model_inputs.input_ids, examples['target'])]
|
|
||||||
# model_inputs["mask_position"] = mask_position
|
|
||||||
# model_inputs["extra_fields"] = examples['extra_fields']
|
|
||||||
# from IPython import embed; embed(header="Therer")
|
|
||||||
return model_inputs
|
|
||||||
|
|
||||||
def compute_metrics(eval_preds, dataset_name):
|
|
||||||
# from IPython import embed; embed(header="In compute metrics")
|
|
||||||
|
|
||||||
preds, labels = eval_preds.predictions, eval_preds.label_ids
|
|
||||||
|
|
||||||
preds = np.argmax(preds, axis=-1)
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
average_metrics = []
|
|
||||||
for metric in eval_metric:
|
|
||||||
metric_item = metric(preds, labels)
|
|
||||||
metric_value = list(metric_item.values())
|
|
||||||
result.update(metric_item)
|
|
||||||
average_metrics.extend(metric_value)
|
|
||||||
print("average:",average_metrics)
|
|
||||||
average_metric = sum(average_metrics)/len(average_metrics)
|
|
||||||
result.update({"average_metrics":average_metric})
|
|
||||||
return result
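A toy version of this classification branch, showing the argmax-then-metrics flow and the averaged summary value that gets logged:

import numpy as np

def accuracy(preds, labels):
    return {"accuracy": float((preds == labels).mean())}

def compute_metrics_sketch(logits, labels, eval_metric=(accuracy,)):
    preds = np.argmax(logits, axis=-1)        # class id per example
    result = {}
    for metric in eval_metric:
        result.update(metric(preds, labels))
    result["average_metrics"] = sum(result.values()) / len(result)
    return result

logits = np.array([[2.0, 0.1], [0.2, 1.5], [0.3, 0.9]])
labels = np.array([0, 1, 0])
print(compute_metrics_sketch(logits, labels))   # accuracy 2/3, average 2/3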
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if training_args.do_train:
|
|
||||||
|
|
||||||
train_task = AutoTask.get(data_args.task_name,
|
|
||||||
data_args.dataset_config_name,
|
|
||||||
data_args=data_args,
|
|
||||||
# tokenizer=tokenizer,
|
|
||||||
# predict_with_generate=training_args.predict_with_generate,
|
|
||||||
seed=data_args.data_seed)
|
|
||||||
|
|
||||||
train_dataset = train_task.get(split='train',
|
|
||||||
split_validation_test=training_args.split_validation_test,
|
|
||||||
n_obs=data_args.max_train_samples)
|
|
||||||
|
|
||||||
template, verbalizer, tokenizer_wrapper = get_prompts(train_task, tokenizer, training_args.predict_with_generate)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
functools.partial(preprocess_function,
|
|
||||||
data_args=data_args,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
template=template,
|
|
||||||
verbalizer=verbalizer,
|
|
||||||
tokenizer_wrapper=tokenizer_wrapper,
|
|
||||||
split="train"),
|
|
||||||
batched=False,
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
remove_columns=[x for x in train_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
|
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
eval_splits_names = []
|
|
||||||
|
|
||||||
if training_args.do_eval:
|
|
||||||
eval_splits_names.append("eval")
|
|
||||||
if training_args.do_test:
|
|
||||||
eval_splits_names.append("test")
|
|
||||||
eval_splits = {}
|
|
||||||
for split_name in eval_splits_names:
|
|
||||||
eval_task = AutoTask.get(data_args.task_name,
|
|
||||||
data_args.dataset_config_name,
|
|
||||||
data_args=data_args,
|
|
||||||
# tokenizer=tokenizer,
|
|
||||||
# predict_with_generate=training_args.predict_with_generate,
|
|
||||||
seed=data_args.data_seed)
|
|
||||||
# for dataset_name, dataset_config_name\
|
|
||||||
# in zip(getattr(data_args,f"{split_name}_dataset_name"), getattr(data_args, f"{split_name}_dataset_config_name"))}
|
|
||||||
|
|
||||||
eval_dataset = eval_task.get(split=split_name,
|
|
||||||
split_validation_test=training_args.split_validation_test,
|
|
||||||
n_obs=data_args.max_train_samples)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template, _verbalizer, tokenizer_wrapper = get_prompts(eval_task, tokenizer, training_args.predict_with_generate)
|
|
||||||
|
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
functools.partial(preprocess_function,
|
|
||||||
data_args=data_args,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
template=template,
|
|
||||||
verbalizer=_verbalizer,
|
|
||||||
tokenizer_wrapper=tokenizer_wrapper,
|
|
||||||
split=split_name),
|
|
||||||
batched=False,
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
remove_columns=[x for x in eval_dataset.features if x not in ("label",)], # if train_dataset != "superglue-record" else column_names+["answers"],
|
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
eval_splits[split_name] = eval_dataset
|
|
||||||
if split_name == "test":
|
|
||||||
eval_metric = eval_task.metric
|
|
||||||
verbalizer = _verbalizer
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MLMTrainer(Trainer):
|
|
||||||
def __init__(self, verbalizer=None, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self.verbalizer=verbalizer
|
|
||||||
|
|
||||||
# def training_step(self, model, inputs):
|
|
||||||
# from IPython import embed; embed(header="in trainstep")
|
|
||||||
# return super().training_step(model, inputs)
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
|
||||||
|
|
||||||
labels = inputs.pop('labels')
|
|
||||||
# extra_fields = inputs.pop("extra_fields")
|
|
||||||
outputs = model(**inputs)
|
|
||||||
logits = outputs.get("logits")
|
|
||||||
input_ids = inputs['input_ids']
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# from IPython import embed; embed(header="382")
|
|
||||||
verbalizer = self.verbalizer.cuda()
|
|
||||||
logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
|
|
||||||
label_logits = verbalizer.process_logits(logits_at_mask)
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
# from IPython import embed; embed(header="In compute loss")
|
|
||||||
loss = loss_fct(label_logits, labels)
|
|
||||||
outputs.logits = label_logits
|
|
||||||
return (loss, outputs) if return_outputs else loss
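The verbalized masked-LM loss above can be reproduced with plain torch: gather the LM logits at the mask positions, keep only the label-word columns (a stand-in for verbalizer.process_logits), and apply cross entropy. The token ids below are made up for illustration.

import torch

batch, seq_len, vocab = 2, 6, 30522
mask_token_id = 103                                  # BERT-style [MASK] id (assumption)
label_word_ids = torch.tensor([2053, 2748])          # e.g. ids for "no" / "yes" (assumption)

input_ids = torch.full((batch, seq_len), 1037, dtype=torch.long)
input_ids[:, 3] = mask_token_id                      # exactly one mask per example
logits = torch.randn(batch, seq_len, vocab)          # stand-in for model(**inputs).logits
labels = torch.tensor([0, 1])

logits_at_mask = logits[torch.where(input_ids == mask_token_id)]   # (batch, vocab)
label_logits = logits_at_mask[:, label_word_ids]                   # (batch, num_classes)
loss = torch.nn.CrossEntropyLoss()(label_logits, labels)
print(label_logits.shape, loss.item())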
|
|
||||||
|
|
||||||
|
|
||||||
class MySeq2SeqTrainer(Seq2SeqTrainer):
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False):
|
|
||||||
# from IPython import embed; embed(header="agag")
|
|
||||||
|
|
||||||
intlabel = inputs.pop('label')
|
|
||||||
# extra_fields = inputs.pop("extra_fields")
|
|
||||||
outputs = model(**inputs)
|
|
||||||
# logits = outputs.get("logits")
|
|
||||||
# input_ids = inputs['input_ids']
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# # from IPython import embed; embed(header="382")
|
|
||||||
# verbalizer = self._verbalizers.cuda()
|
|
||||||
# logits_at_mask = logits[torch.where(input_ids == verbalizer.tokenizer.mask_token_id)]
|
|
||||||
# label_logits = verbalizer.process_logits(logits_at_mask)
|
|
||||||
# loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
# # from IPython import embed; embed(header="In compute loss")
|
|
||||||
# loss = loss_fct(label_logits, labels)
|
|
||||||
# outputs.logits = label_logits
|
|
||||||
if return_outputs:
|
|
||||||
return (outputs.loss, outputs)
|
|
||||||
else:
|
|
||||||
return outputs.loss
|
|
||||||
|
|
||||||
|
|
||||||
# def evaluate(
|
|
||||||
# self,
|
|
||||||
# eval_dataset: Optional[Dict[str, Dataset]] = None,
|
|
||||||
# ignore_keys: Optional[List[str]] = None,
|
|
||||||
# metric_key_prefix: str = "eval",
|
|
||||||
# max_length: Optional[int] = None,
|
|
||||||
# num_beams: Optional[int] = None,
|
|
||||||
# ) -> Dict[str, float]:
|
|
||||||
# # TODO: this also needs to be set per dataset
|
|
||||||
# self._max_length = max_length
|
|
||||||
# self._num_beams = num_beams
|
|
||||||
# return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
|
||||||
|
|
||||||
|
|
||||||
def prediction_step(
|
|
||||||
self,
|
|
||||||
model, #nn.Module,
|
|
||||||
inputs, #Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
prediction_loss_only, #: bool,
|
|
||||||
ignore_keys, #: Optional[List[str]] = None,
|
|
||||||
): #-> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
|
||||||
|
|
||||||
Subclass and override to inject custom behavior.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (:obj:`nn.Module`):
|
|
||||||
The model to evaluate.
|
|
||||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
|
||||||
The inputs and targets of the model.
|
|
||||||
|
|
||||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
|
||||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
|
||||||
prediction_loss_only (:obj:`bool`):
|
|
||||||
Whether or not to return the loss only.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
|
||||||
labels (each being optional).
|
|
||||||
"""
|
|
||||||
if not self.args.predict_with_generate or prediction_loss_only:
|
|
||||||
return super().prediction_step(
|
|
||||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
has_labels = "labels" in inputs
|
|
||||||
inputs = self._prepare_inputs(inputs)
|
|
||||||
intlabel = inputs.pop('label')
|
|
||||||
gen_kwargs = {
|
|
||||||
"max_length": 10, # self._max_length if s is not None else self.model.config.max_length,
|
|
||||||
"num_beams": 1 #self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
|
||||||
}
|
|
||||||
generated_tokens = self.model.generate(
|
|
||||||
inputs["input_ids"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
**gen_kwargs,
|
|
||||||
)
|
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
|
||||||
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
|
||||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
|
|
||||||
outputs = model(**inputs)
|
|
||||||
if has_labels:
|
|
||||||
if self.label_smoother is not None:
|
|
||||||
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
|
||||||
else:
|
|
||||||
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
|
||||||
else:
|
|
||||||
loss = None
|
|
||||||
|
|
||||||
if self.args.prediction_loss_only:
|
|
||||||
return (loss, None, None)
|
|
||||||
|
|
||||||
labels = inputs["labels"]
|
|
||||||
if labels.shape[-1] < gen_kwargs["max_length"]:
|
|
||||||
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
|
||||||
|
|
||||||
# from IPython import embed; embed(header="In seqseqtrainer")
|
|
||||||
return (loss, generated_tokens, labels)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys):
|
|
||||||
# aa = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
|
|
||||||
# # from IPython import embed; embed()
|
|
||||||
# return aa
|
|
||||||
# from transformers.data.data_collator import torch_default_data_collator , DataCollatorMixin
|
|
||||||
# class DataCollatorWithExtraFields(DataCollatorMixin):
|
|
||||||
# return_tensors: str = "pt"
|
|
||||||
# def torch_call(self, features):
|
|
||||||
# # print(len(features))
|
|
||||||
# # extra_fields = [f.pop('extra_fields') for f in features]
|
|
||||||
# batch = torch_default_data_collator(features)
|
|
||||||
# batch['extra_fields'] =extra_fields
|
|
||||||
# # print(batch['input_ids'].size())
|
|
||||||
# # print(batch['labels'].size())
|
|
||||||
# return batch
|
|
||||||
|
|
||||||
|
|
||||||
# from transformers.data.data_collator import DefaultDataCollator
|
|
||||||
# class CustomDataCollator(DefaultDataCollator):
|
|
||||||
|
|
||||||
# def __call__(self, features):
|
|
||||||
# mask_position = [d.pop('mask_position') for d in features]
|
|
||||||
# # self.check_uniqueness(tasks)
|
|
||||||
# from IPython import embed; embed(header="featurres")
|
|
||||||
# output = super().__call__(features)
|
|
||||||
# # mask_positions = [d.pop('mask_position') for d in features]
|
|
||||||
# output["mask_position"] = mask_position
|
|
||||||
# return output
|
|
||||||
|
|
||||||
|
|
||||||
training_args.remove_unused_columns = False
|
|
||||||
|
|
||||||
if os.path.basename(model_args.model_name_or_path).startswith("roberta") or \
|
|
||||||
os.path.basename(model_args.model_name_or_path).startswith("bert"):
|
|
||||||
trainer = MLMTrainer(
|
|
||||||
model=model,
|
|
||||||
args=training_args,
|
|
||||||
train_dataset=train_dataset if training_args.do_train else None,
|
|
||||||
eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
|
|
||||||
compute_metrics=functools.partial(compute_metrics, dataset_name=data_args.task_name),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
# data_collator=DataCollatorWithExtraFields(),
|
|
||||||
verbalizer=verbalizer,
|
|
||||||
)
|
|
||||||
elif os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
|
||||||
trainer = MySeq2SeqTrainer(
|
|
||||||
model=model,
|
|
||||||
args=training_args,
|
|
||||||
train_dataset=train_dataset if training_args.do_train else None,
|
|
||||||
eval_dataset=eval_splits['eval'] if training_args.do_eval else None,
|
|
||||||
compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer, dataset_name=data_args.task_name, eval_metric=eval_metric),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Saves training config.
|
|
||||||
if trainer.is_world_process_zero():
|
|
||||||
os.makedirs(training_args.output_dir, exist_ok=True)
|
|
||||||
save_training_config(sys.argv[1], training_args.output_dir)
|
|
||||||
|
|
||||||
# Training
|
|
||||||
if training_args.do_train:
|
|
||||||
checkpoint = None
|
|
||||||
if training_args.resume_from_checkpoint is not None:
|
|
||||||
checkpoint = training_args.resume_from_checkpoint
|
|
||||||
elif last_checkpoint is not None:
|
|
||||||
checkpoint = last_checkpoint
|
|
||||||
|
|
||||||
if training_args.compute_time:
|
|
||||||
torch.cuda.synchronize() # wait for move to complete
|
|
||||||
start = torch.cuda.Event(enable_timing=True)
|
|
||||||
end = torch.cuda.Event(enable_timing=True)
|
|
||||||
start.record()
|
|
||||||
|
|
||||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
|
||||||
|
|
||||||
if training_args.compute_time:
|
|
||||||
end.record()
|
|
||||||
torch.cuda.synchronize() # wait for all_reduce to complete
|
|
||||||
total_time = start.elapsed_time(end)/(1000*60)
|
|
||||||
performance_metrics.update({"total_time in minutes ": total_time})
|
|
||||||
|
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
|
||||||
train_metrics = train_result.metrics
|
|
||||||
max_train_samples = (
|
|
||||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
|
||||||
)
|
|
||||||
train_metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
|
||||||
trainer.log_metrics("train", train_metrics)
|
|
||||||
trainer.save_metrics("train", train_metrics)
|
|
||||||
trainer.save_state()
|
|
||||||
|
|
||||||
if torch.cuda.is_available() and training_args.compute_memory:
|
|
||||||
peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2)/1000
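# Note: torch.cuda.max_memory_allocated() reports bytes; dividing by 1024 ** 2 converts to MiB,
# and the additional /1000 turns that into the approximate GB figure printed below.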
|
|
||||||
print(
|
|
||||||
"Memory utilization",
|
|
||||||
peak_memory,
|
|
||||||
"GB"
|
|
||||||
)
|
|
||||||
performance_metrics.update({"peak_memory": peak_memory})
|
|
||||||
if training_args.compute_memory or training_args.compute_time:
|
|
||||||
print(performance_metrics)
|
|
||||||
trainer.save_metrics("performance", performance_metrics)
|
|
||||||
|
|
||||||
# Evaluation
|
|
||||||
all_results = {}
|
|
||||||
|
|
||||||
all_results['evaluate'] = {}
|
|
||||||
|
|
||||||
if training_args.do_eval:
|
|
||||||
logger.info("*** Evaluate ***")
|
|
||||||
|
|
||||||
metrics = trainer.evaluate(eval_dataset=eval_splits['eval'],
|
|
||||||
)
|
|
||||||
trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
|
|
||||||
trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
|
|
||||||
all_results['evaluate'][data_args.task_name] = metrics
|
|
||||||
|
|
||||||
# Test
|
|
||||||
all_results['test'] = {}
|
|
||||||
if training_args.do_test:
|
|
||||||
logger.info("*** Test ***")
|
|
||||||
metrics = trainer.evaluate(eval_dataset=eval_splits['test'],
|
|
||||||
metric_key_prefix="test"
|
|
||||||
)
|
|
||||||
trainer.log_metrics(f"{data_args.task_name}_test", metrics)
|
|
||||||
trainer.save_metrics(f"{data_args.task_name}_test", metrics)
|
|
||||||
all_results['test'][data_args.task_name] = metrics
|
|
||||||
|
|
||||||
# repo_name = create_hub_repo_name(root="DeltaHub",
|
|
||||||
# dataset=data_args.task_name,
|
|
||||||
# delta_type = delta_args.delta_type,
|
|
||||||
# model_name_or_path= model_args.model_name_or_path)
|
|
||||||
# results['repo_name'] = repo_name
|
|
||||||
# if delta_args.delta_type.lower() != "none":
|
|
||||||
# if training_args.push_to_hub: # TODO add description here
|
|
||||||
# delta_model.save_finetuned(push_to_hub=True, save_directory=repo_name, use_auth_token=True)
|
|
||||||
# # trainer.push_to_hub(**kwargs)
|
|
||||||
# else:
|
|
||||||
# delta_model.save_finetuned(push_to_hub=False, save_directory=repo_name, use_auth_token=True)
|
|
||||||
|
|
||||||
|
|
||||||
with open(f"{training_args.output_dir}/results.json", 'w') as fout:
|
|
||||||
string = json.dumps(all_results, indent=4,sort_keys=True)
|
|
||||||
fout.write(string+"\n")
|
|
||||||
|
|
||||||
return all_results
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
result = main()
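# Hedged usage note: this entry point expects a single JSON config file whose fields populate the
# argument dataclasses (sys.argv[1] is also copied to the output dir as training_config.json).
# An illustrative invocation, with a hypothetical script name and config path:
#   python this_script.py configs/<task_name>/<delta_type>.json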
|
|
||||||
|
|
|
@ -0,0 +1,323 @@
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright OpenDelta Team and THUNLP lab. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
A unified runing scripts for most models to do down stream tasks in a
|
||||||
|
prompt learning fashion, i.e., No classification head, all tasks are casted
|
||||||
|
to mask prediction or span prediction tasks.
|
||||||
|
|
||||||
|
Processing relevant to different backbone models are stored in ../backbones/
|
||||||
|
|
||||||
|
Adding A few lines to integrate the Delta tuning methods.
|
||||||
|
|
||||||
|
You can also adapt this script on your own tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
os.environ['MKL_THREADING_LAYER'] = 'GNU'
|
||||||
|
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
sys.path.append(os.path.join(os.getcwd(), "../"))
|
||||||
|
sys.path.append(os.path.join(os.getcwd()))
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForMaskedLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
DataCollatorForSeq2Seq,
|
||||||
|
# HfArgumentParser,
|
||||||
|
# MBartTokenizer,
|
||||||
|
# default_data_collator,
|
||||||
|
Trainer,
|
||||||
|
Seq2SeqTrainer,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from transformers.trainer_utils import is_main_process, get_last_checkpoint
|
||||||
|
|
||||||
|
from data_processors import AutoTask #, #TaskDataCollatorForSeq2Seq, AutoPostProcessor, data_collator
|
||||||
|
from utils import read_json, save_json
|
||||||
|
from utils.args import ModelArguments, TrainingArguments, DataTrainingArguments, RemainArgHfArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# See all possible arguments in src/transformers/training_args.py
|
||||||
|
# or by passing the --help flag to this script.
|
||||||
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
||||||
|
parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
||||||
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||||
|
# If we pass only one argument to the script and it's the path to a json file,
|
||||||
|
# let's parse it to get our arguments.
|
||||||
|
model_args, data_args, training_args, delta_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||||
|
else:
|
||||||
|
model_args, data_args, training_args, delta_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
|
|
||||||
|
print(f"{training_args.output_dir}/results.json")
|
||||||
|
# exit()
|
||||||
|
# Detecting last checkpoint.
|
||||||
|
last_checkpoint = None
|
||||||
|
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
||||||
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
|
print("#### last_checkpoint ", last_checkpoint)
|
||||||
|
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||||
|
'''
|
||||||
|
raise ValueError(
|
||||||
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
||||||
|
"Use --overwrite_output_dir to overcome."
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
elif last_checkpoint is not None:
|
||||||
|
logger.info(
|
||||||
|
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
||||||
|
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
|
||||||
|
|
||||||
|
# Log on each process the small summary:
|
||||||
|
logger.warning(
|
||||||
|
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||||
|
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||||
|
)
|
||||||
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||||
|
if is_main_process(training_args.local_rank):
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
# logger.info("Training/evaluation parameters %s", training_args, model_args, data_args, delta_args)
|
||||||
|
logger.info("{}\n{}\n{}\n{}".format(training_args, model_args, data_args, delta_args))
|
||||||
|
|
||||||
|
|
||||||
|
# Set seed before initializing model.
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if os.path.basename(model_args.model_name_or_path).startswith("t5"):
|
||||||
|
from examples_prompt.backbones.t5 import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
|
||||||
|
from examples_prompt.backbones.t5 import Trainer, DataCollator
|
||||||
|
elif os.path.basename(model_args.model_name_or_path).startswith("roberta") \
|
||||||
|
or os.path.basename(model_args.model_name_or_path).startswith("bert") \
|
||||||
|
or os.path.basename(model_args.model_name_or_path).startswith("albert") :
|
||||||
|
from examples_prompt.backbones.bert import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
|
||||||
|
from examples_prompt.backbones.bert import Trainer, DataCollator
|
||||||
|
elif os.path.basename(model_args.model_name_or_path).startswith("beit"):
|
||||||
|
from examples_prompt.backbones.beit import get_backbone, preprocess_function, mask_token_func, get_remove_columns, get_prompts
|
||||||
|
from examples_prompt.backbones.beit import Trainer, DataCollator
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
config, tokenizer, model = get_backbone(model_args=model_args)
|
||||||
|
|
||||||
|
if delta_args.delta_type.lower() != "none":
|
||||||
|
from opendelta import AutoDeltaConfig,AutoDeltaModel
|
||||||
|
delta_config = AutoDeltaConfig.from_dict(vars(delta_args))
|
||||||
|
delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
|
||||||
|
delta_model.freeze_module(set_state_dict = True)
|
||||||
|
delta_model.log(delta_ratio=True, trainable_ratio=True, visualization=True)
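# The four calls above are the whole OpenDelta integration: build a delta config from the parsed
# arguments, attach the delta modules to the backbone, freeze everything except the delta (and any
# explicitly unfrozen) parameters, and log the trainable-parameter ratio. A minimal standalone
# sketch of the same flow (the config dict below is illustrative, not taken from this repo):
#
#   delta_config = AutoDeltaConfig.from_dict({"delta_type": "lora"})
#   delta_model = AutoDeltaModel.from_config(delta_config, backbone_model=model)
#   delta_model.freeze_module(set_state_dict=True)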
|
||||||
|
|
||||||
|
# model parallelize
|
||||||
|
if hasattr(training_args, "model_parallel") and training_args.model_parallel:
|
||||||
|
logger.info('parallelize model!')
|
||||||
|
model.parallelize()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
performance_metrics = {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
non_empty_splits_names = []
|
||||||
|
if training_args.do_train:
|
||||||
|
non_empty_splits_names.append("train")
|
||||||
|
if training_args.do_eval:
|
||||||
|
non_empty_splits_names.append("eval")
|
||||||
|
if training_args.do_test:
|
||||||
|
non_empty_splits_names.append("test")
|
||||||
|
splits = {}
|
||||||
|
for split_name in ['train', 'eval', 'test']:
|
||||||
|
if split_name not in non_empty_splits_names:
|
||||||
|
splits[split_name] = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
task = AutoTask.get(data_args.task_name,
|
||||||
|
data_args.dataset_config_name,
|
||||||
|
data_args=data_args,
|
||||||
|
seed=data_args.data_seed)
|
||||||
|
|
||||||
|
dataset = task.get(split=split_name,
|
||||||
|
split_validation_test=training_args.split_validation_test,
|
||||||
|
n_obs=data_args.max_train_samples)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template, _verbalizer, tokenizer_wrapper = get_prompts(task, tokenizer, training_args)
|
||||||
|
|
||||||
|
|
||||||
|
dataset = dataset.map(
|
||||||
|
functools.partial(preprocess_function,
|
||||||
|
data_args=data_args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
template=template,
|
||||||
|
verbalizer=_verbalizer,
|
||||||
|
tokenizer_wrapper=tokenizer_wrapper,
|
||||||
|
split=split_name),
|
||||||
|
batched=False,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=get_remove_columns(list(dataset.features.keys())),
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
# from IPython import embed; embed()
|
||||||
|
splits[split_name] = dataset
|
||||||
|
if split_name == "eval":
|
||||||
|
eval_task = task
|
||||||
|
verbalizer = _verbalizer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
verbalizer=verbalizer,
|
||||||
|
eval_task=eval_task,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=splits['train'],
|
||||||
|
eval_dataset=splits['eval'],
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=DataCollator(tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_training_config(config_file, output_dir):
|
||||||
|
json_data = read_json(config_file)
|
||||||
|
save_json(os.path.join(output_dir, "training_config.json"), json_data)
|
||||||
|
|
||||||
|
|
||||||
|
# Saves training config.
|
||||||
|
if trainer.is_world_process_zero():
|
||||||
|
save_training_config(sys.argv[1], training_args.output_dir)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
checkpoint = None
|
||||||
|
if training_args.resume_from_checkpoint is not None:
|
||||||
|
checkpoint = training_args.resume_from_checkpoint
|
||||||
|
elif last_checkpoint is not None:
|
||||||
|
checkpoint = last_checkpoint
|
||||||
|
|
||||||
|
if training_args.compute_time:
|
||||||
|
torch.cuda.synchronize() # wait for move to complete
|
||||||
|
start = torch.cuda.Event(enable_timing=True)
|
||||||
|
end = torch.cuda.Event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
|
|
||||||
|
if training_args.compute_time:
|
||||||
|
end.record()
|
||||||
|
torch.cuda.synchronize() # wait for all_reduce to complete
|
||||||
|
total_time = start.elapsed_time(end)/(1000*60)
|
||||||
|
performance_metrics.update({"total_time in minutes ": total_time})
|
||||||
|
|
||||||
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
train_metrics = train_result.metrics
|
||||||
|
max_train_samples = (
|
||||||
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(splits['train'])
|
||||||
|
)
|
||||||
|
train_metrics["train_samples"] = min(max_train_samples, len(splits['train']))
|
||||||
|
trainer.log_metrics("train", train_metrics)
|
||||||
|
trainer.save_metrics("train", train_metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
|
||||||
|
if torch.cuda.is_available() and training_args.compute_memory:
|
||||||
|
peak_memory = (torch.cuda.max_memory_allocated() / 1024 ** 2)/1000
|
||||||
|
print(
|
||||||
|
"Memory utilization",
|
||||||
|
peak_memory,
|
||||||
|
"GB"
|
||||||
|
)
|
||||||
|
performance_metrics.update({"peak_memory": peak_memory})
|
||||||
|
if training_args.compute_memory or training_args.compute_time:
|
||||||
|
print("Efficiency Statistics {}".format(performance_metrics))
|
||||||
|
trainer.save_metrics("performance", performance_metrics)
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
all_results = {}
|
||||||
|
|
||||||
|
all_results['evaluate'] = {}
|
||||||
|
|
||||||
|
if training_args.do_eval:
|
||||||
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
|
metrics = trainer.evaluate(eval_dataset=splits['eval'],
|
||||||
|
)
|
||||||
|
trainer.log_metrics(f"{data_args.task_name}_eval", metrics)
|
||||||
|
trainer.save_metrics(f"{data_args.task_name}_eval", metrics)
|
||||||
|
all_results['evaluate'][data_args.task_name] = metrics
|
||||||
|
|
||||||
|
# Test
|
||||||
|
all_results['test'] = {}
|
||||||
|
if training_args.do_test:
|
||||||
|
logger.info("*** Test ***")
|
||||||
|
metrics = trainer.evaluate(eval_dataset=splits['test'],
|
||||||
|
metric_key_prefix="test"
|
||||||
|
)
|
||||||
|
trainer.log_metrics(f"{data_args.task_name}_test", metrics)
|
||||||
|
trainer.save_metrics(f"{data_args.task_name}_test", metrics)
|
||||||
|
all_results['test'][data_args.task_name] = metrics
|
||||||
|
|
||||||
|
# repo_name = create_hub_repo_name(root="DeltaHub",
|
||||||
|
# dataset=data_args.task_name,
|
||||||
|
# delta_type = delta_args.delta_type,
|
||||||
|
# model_name_or_path= model_args.model_name_or_path)
|
||||||
|
# results['repo_name'] = repo_name
|
||||||
|
# if delta_args.delta_type.lower() != "none":
|
||||||
|
# if training_args.push_to_hub: # TODO add description here
|
||||||
|
# delta_model.save_finetuned(push_to_hub=True, save_directory=repo_name, use_auth_token=True)
|
||||||
|
# # trainer.push_to_hub(**kwargs)
|
||||||
|
# else:
|
||||||
|
# delta_model.save_finetuned(push_to_hub=False, save_directory=repo_name, use_auth_token=True)
|
||||||
|
|
||||||
|
|
||||||
|
with open(f"{training_args.output_dir}/results.json", 'w') as fout:
|
||||||
|
string = json.dumps(all_results, indent=4,sort_keys=True)
|
||||||
|
fout.write(string+"\n")
|
||||||
|
|
||||||
|
return all_results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
result = main()
|
||||||
|
|
|
@ -1,2 +0,0 @@
from .trainer import BaseTrainer
from .seq2seq_trainer import Seq2SeqTrainer
@ -1,36 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArguments:
|
|
||||||
"""
|
|
||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
|
||||||
"""
|
|
||||||
model_name_or_path: str = field(
|
|
||||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
|
||||||
)
|
|
||||||
config_name: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
|
||||||
)
|
|
||||||
tokenizer_name: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
|
||||||
)
|
|
||||||
cache_dir: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
|
||||||
)
|
|
||||||
use_fast_tokenizer: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
|
||||||
)
|
|
||||||
model_revision: str = field(
|
|
||||||
default="main",
|
|
||||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
|
||||||
)
|
|
||||||
use_auth_token: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
|
||||||
"with private models)."
|
|
||||||
},
|
|
||||||
)
|
|
|
@ -1,108 +0,0 @@
|
||||||
from packaging import version
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
from transformers import Seq2SeqTrainer
|
|
||||||
from .trainer import BaseTrainer
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainer(Seq2SeqTrainer, BaseTrainer):
|
|
||||||
def __init__(self, train_dataset_sizes=None, delta_args=None, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.train_dataset_sizes = train_dataset_sizes
|
|
||||||
self.delta_args = delta_args
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
self,
|
|
||||||
eval_dataset: Optional[Dict[str, Dataset]] = None,
|
|
||||||
ignore_keys: Optional[List[str]] = None,
|
|
||||||
metric_key_prefix: str = "eval",
|
|
||||||
max_length: Optional[int] = None,
|
|
||||||
num_beams: Optional[int] = None,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
# TODO: this also needs to be set per dataset
|
|
||||||
self._max_length = max_length
|
|
||||||
self._num_beams = num_beams
|
|
||||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
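# Descriptive note: max_length and num_beams are only stashed on the instance here; the
# overridden prediction_step below reads self._max_length / self._num_beams when it builds
# gen_kwargs for model.generate().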
|
|
||||||
|
|
||||||
|
|
||||||
def prediction_step(
|
|
||||||
self,
|
|
||||||
model: nn.Module,
|
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
|
||||||
prediction_loss_only: bool,
|
|
||||||
ignore_keys: Optional[List[str]] = None,
|
|
||||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Perform an evaluation step on :obj:`model` using :obj:`inputs`.
|
|
||||||
|
|
||||||
Subclass and override to inject custom behavior.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (:obj:`nn.Module`):
|
|
||||||
The model to evaluate.
|
|
||||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
|
||||||
The inputs and targets of the model.
|
|
||||||
|
|
||||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
|
||||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
|
||||||
prediction_loss_only (:obj:`bool`):
|
|
||||||
Whether or not to return the loss only.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
|
|
||||||
labels (each being optional).
|
|
||||||
"""
|
|
||||||
if not self.args.predict_with_generate or prediction_loss_only:
|
|
||||||
return super().prediction_step(
|
|
||||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
|
||||||
)
|
|
||||||
|
|
||||||
has_labels = "labels" in inputs
|
|
||||||
inputs = self._prepare_inputs(inputs)
|
|
||||||
gen_kwargs = {
|
|
||||||
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
|
|
||||||
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
|
|
||||||
}
|
|
||||||
generated_tokens = self.model.generate(
|
|
||||||
inputs["input_ids"],
|
|
||||||
attention_mask=inputs["attention_mask"],
|
|
||||||
**gen_kwargs,
|
|
||||||
)
|
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
|
||||||
if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
|
|
||||||
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
if self.use_amp:
|
|
||||||
with autocast():
|
|
||||||
outputs = model(**inputs)
|
|
||||||
else:
|
|
||||||
outputs = model(**inputs)
|
|
||||||
if has_labels:
|
|
||||||
if self.label_smoother is not None:
|
|
||||||
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
|
|
||||||
else:
|
|
||||||
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
|
|
||||||
else:
|
|
||||||
loss = None
|
|
||||||
|
|
||||||
if self.args.prediction_loss_only:
|
|
||||||
return (loss, None, None)
|
|
||||||
|
|
||||||
labels = inputs["labels"]
|
|
||||||
if labels.shape[-1] < gen_kwargs["max_length"]:
|
|
||||||
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
|
|
||||||
|
|
||||||
return (loss, generated_tokens, labels)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,274 +0,0 @@
|
||||||
from typing import Dict, List, Optional
|
|
||||||
import numpy as np
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import collections
|
|
||||||
from packaging import version
|
|
||||||
from torch.utils.data.dataset import Dataset
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
from transformers import logging
|
|
||||||
from transformers.trainer_utils import (
|
|
||||||
speed_metrics,
|
|
||||||
EvalLoopOutput,
|
|
||||||
denumpify_detensorize
|
|
||||||
)
|
|
||||||
from transformers.file_utils import is_torch_tpu_available
|
|
||||||
from transformers.trainer_pt_utils import (
|
|
||||||
find_batch_size,
|
|
||||||
nested_numpify,
|
|
||||||
nested_truncate,
|
|
||||||
nested_concat,
|
|
||||||
IterableDatasetShard
|
|
||||||
)
|
|
||||||
from .trainer_utils import EvalPrediction
|
|
||||||
|
|
||||||
|
|
||||||
from torch.utils.data.dataloader import DataLoader
|
|
||||||
from torch.utils.data.dataset import IterableDataset
|
|
||||||
from transformers.deepspeed import deepspeed_init
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
|
||||||
from torch.cuda.amp import autocast
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
import torch_xla.debug.metrics as met
|
|
||||||
import torch_xla.distributed.parallel_loader as pl
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
class BaseTrainer(Trainer):
|
|
||||||
def __init__(self, evaluation_metrics=[], data_info=None, *args, **kwargs):
|
|
||||||
"""When doing evaluation, it computes average of list of metrics
|
|
||||||
given in evaluation_metrics and adds it to the dictionary of results.
|
|
||||||
Trainer class then use this average metric to save the best model."""
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.evaluation_metrics = evaluation_metrics
|
|
||||||
self.data_info = data_info
|
|
||||||
|
|
||||||
def get_data_info(self, metric_key_prefix):
|
|
||||||
"""Returns the data information required to make the predictions/labels
|
|
||||||
suitable for the evaluation."""
|
|
||||||
if self.data_info is not None:
|
|
||||||
return self.data_info[metric_key_prefix]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def evaluate(
|
|
||||||
self,
|
|
||||||
eval_dataset: Optional[Dataset] = None,
|
|
||||||
ignore_keys: Optional[List[str]] = None,
|
|
||||||
metric_key_prefix: str = "eval",
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""
|
|
||||||
Run evaluation and returns metrics.
|
|
||||||
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
|
|
||||||
(pass it to the init :obj:`compute_metrics` argument).
|
|
||||||
You can also subclass and override this method to inject custom behavior.
|
|
||||||
Args:
|
|
||||||
eval_dataset (:obj:`Dataset`, `optional`):
|
|
||||||
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
|
|
||||||
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
|
|
||||||
:obj:`__len__` method.
|
|
||||||
ignore_keys (:obj:`List[str]`, `optional`):
|
|
||||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
|
||||||
gathering predictions.
|
|
||||||
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
|
||||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
|
||||||
"eval_bleu" if the prefix is "eval" (default)
|
|
||||||
Returns:
|
|
||||||
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
|
|
||||||
dictionary also contains the epoch number which comes from the training state.
|
|
||||||
"""
|
|
||||||
# memory metrics - must set up as early as possible
|
|
||||||
self._memory_tracker.start()
|
|
||||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
|
||||||
start_time = time.time()
|
|
||||||
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
|
|
||||||
output = eval_loop(
|
|
||||||
eval_dataloader,
|
|
||||||
description="Evaluation",
|
|
||||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
|
||||||
# self.args.prediction_loss_only
|
|
||||||
prediction_loss_only=True if self.compute_metrics is None else None,
|
|
||||||
ignore_keys=ignore_keys,
|
|
||||||
metric_key_prefix=metric_key_prefix,
|
|
||||||
)
|
|
||||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
|
||||||
if len(self.evaluation_metrics) != 0:
|
|
||||||
selected_metrics = [output.metrics[metric_key_prefix+"_"+k] for k in self.evaluation_metrics if metric_key_prefix+"_"+k in output.metrics]
|
|
||||||
assert len(selected_metrics) >= 1, "at least one metric should be selected to compute the average_metrics."
|
|
||||||
output.metrics.update({metric_key_prefix+'_average_metrics': np.mean(selected_metrics)})
|
|
||||||
|
|
||||||
self.log(output.metrics)
|
|
||||||
|
|
||||||
if self.args.tpu_metrics_debug or self.args.debug:
|
|
||||||
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
|
||||||
xm.master_print(met.metrics_report())
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
|
|
||||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
|
||||||
return output.metrics
|
|
||||||
|
|
||||||
def evaluation_loop(
|
|
||||||
self,
|
|
||||||
dataloader: DataLoader,
|
|
||||||
description: str,
|
|
||||||
prediction_loss_only: Optional[bool] = None,
|
|
||||||
ignore_keys: Optional[List[str]] = None,
|
|
||||||
metric_key_prefix: str = "eval",
|
|
||||||
) -> EvalLoopOutput:
|
|
||||||
"""
|
|
||||||
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
|
|
||||||
|
|
||||||
Works both with or without labels.
|
|
||||||
"""
|
|
||||||
prediction_loss_only = (
|
|
||||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
|
||||||
)
|
|
||||||
|
|
||||||
# if eval is called w/o train init deepspeed here
|
|
||||||
if self.args.deepspeed and not self.deepspeed:
|
|
||||||
|
|
||||||
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
|
|
||||||
# from the checkpoint eventually
|
|
||||||
deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
|
|
||||||
self.model = deepspeed_engine.module
|
|
||||||
self.model_wrapped = deepspeed_engine
|
|
||||||
self.deepspeed = deepspeed_engine
|
|
||||||
# XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
|
|
||||||
# for example the Z3-optimizer is a must for zero3 to work even for inference - what we
|
|
||||||
# don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
|
|
||||||
deepspeed_engine.optimizer.optimizer = None
|
|
||||||
deepspeed_engine.lr_scheduler = None
|
|
||||||
|
|
||||||
model = self._wrap_model(self.model, training=False)
|
|
||||||
|
|
||||||
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
|
|
||||||
# ``train`` is running, halve it first and then put on device
|
|
||||||
if not self.is_in_train and self.args.fp16_full_eval:
|
|
||||||
model = model.half().to(self.args.device)
|
|
||||||
|
|
||||||
batch_size = dataloader.batch_size
|
|
||||||
|
|
||||||
logger.info(f"***** Running {description} *****")
|
|
||||||
if isinstance(dataloader.dataset, collections.abc.Sized):
|
|
||||||
logger.info(f" Num examples = {self.num_examples(dataloader)}")
|
|
||||||
else:
|
|
||||||
logger.info(" Num examples: Unknown")
|
|
||||||
logger.info(f" Batch size = {batch_size}")
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
self.callback_handler.eval_dataloader = dataloader
|
|
||||||
# Do this before wrapping.
|
|
||||||
eval_dataset = dataloader.dataset
|
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
|
||||||
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
|
||||||
|
|
||||||
if self.args.past_index >= 0:
|
|
||||||
self._past = None
|
|
||||||
|
|
||||||
# Initialize containers
|
|
||||||
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
|
|
||||||
losses_host = None
|
|
||||||
preds_host = None
|
|
||||||
labels_host = None
|
|
||||||
# losses/preds/labels on CPU (final containers)
|
|
||||||
all_losses = None
|
|
||||||
all_preds = None
|
|
||||||
all_labels = None
|
|
||||||
# Will be useful when we have an iterable dataset so don't know its length.
|
|
||||||
|
|
||||||
observed_num_examples = 0
|
|
||||||
# Main evaluation loop
|
|
||||||
for step, inputs in enumerate(dataloader):
|
|
||||||
# Update the observed num examples
|
|
||||||
observed_batch_size = find_batch_size(inputs)
|
|
||||||
if observed_batch_size is not None:
|
|
||||||
observed_num_examples += observed_batch_size
|
|
||||||
|
|
||||||
# Prediction step
|
|
||||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
|
|
||||||
# Update containers on host
|
|
||||||
if loss is not None:
|
|
||||||
losses = self._nested_gather(loss.repeat(batch_size))
|
|
||||||
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
|
||||||
if logits is not None:
|
|
||||||
logits = self._pad_across_processes(logits)
|
|
||||||
logits = self._nested_gather(logits)
|
|
||||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
|
||||||
if labels is not None:
|
|
||||||
labels = self._pad_across_processes(labels)
|
|
||||||
labels = self._nested_gather(labels)
|
|
||||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
|
||||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
|
||||||
|
|
||||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
|
||||||
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
|
|
||||||
if losses_host is not None:
|
|
||||||
losses = nested_numpify(losses_host)
|
|
||||||
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
|
|
||||||
if preds_host is not None:
|
|
||||||
logits = nested_numpify(preds_host)
|
|
||||||
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
|
|
||||||
if labels_host is not None:
|
|
||||||
labels = nested_numpify(labels_host)
|
|
||||||
all_labels = (
|
|
||||||
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set back to None to begin a new accumulation
|
|
||||||
losses_host, preds_host, labels_host = None, None, None
|
|
||||||
|
|
||||||
if self.args.past_index and hasattr(self, "_past"):
|
|
||||||
# Clean the state at the end of the evaluation loop
|
|
||||||
delattr(self, "_past")
|
|
||||||
|
|
||||||
# Gather all remaining tensors and put them back on the CPU
|
|
||||||
if losses_host is not None:
|
|
||||||
losses = nested_numpify(losses_host)
|
|
||||||
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
|
|
||||||
if preds_host is not None:
|
|
||||||
logits = nested_numpify(preds_host)
|
|
||||||
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
|
|
||||||
if labels_host is not None:
|
|
||||||
labels = nested_numpify(labels_host)
|
|
||||||
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
|
|
||||||
|
|
||||||
# Number of samples
|
|
||||||
if not isinstance(eval_dataset, IterableDataset):
|
|
||||||
num_samples = len(eval_dataset)
|
|
||||||
elif isinstance(eval_dataset, IterableDatasetShard):
|
|
||||||
num_samples = eval_dataset.num_examples
|
|
||||||
else:
|
|
||||||
num_samples = observed_num_examples
|
|
||||||
|
|
||||||
# Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
|
|
||||||
# samples has been rounded to a multiple of batch_size, so we truncate.
|
|
||||||
if all_losses is not None:
|
|
||||||
all_losses = all_losses[:num_samples]
|
|
||||||
if all_preds is not None:
|
|
||||||
all_preds = nested_truncate(all_preds, num_samples)
|
|
||||||
if all_labels is not None:
|
|
||||||
all_labels = nested_truncate(all_labels, num_samples)
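# Illustrative note on the truncation above: in distributed evaluation the sampler pads the
# dataset so every process sees the same number of batches, so the gathered losses/preds/labels
# can contain more rows than there are real examples; slicing back to num_samples drops those
# padded duplicates before metrics are computed.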
|
|
||||||
# Metrics!
|
|
||||||
if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
|
|
||||||
metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels,
|
|
||||||
data_info=self.get_data_info(metric_key_prefix)))
|
|
||||||
else:
|
|
||||||
metrics = {}
|
|
||||||
|
|
||||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
|
||||||
metrics = denumpify_detensorize(metrics)
|
|
||||||
|
|
||||||
if all_losses is not None:
|
|
||||||
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
|
|
||||||
|
|
||||||
# Prefix all keys with metric_key_prefix + '_'
|
|
||||||
for key in list(metrics.keys()):
|
|
||||||
if not key.startswith(f"{metric_key_prefix}_"):
|
|
||||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
|
|
||||||
return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
|
|
|
@ -1,75 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from typing import Union, NamedTuple, Tuple, Dict, Any
|
|
||||||
import os
|
|
||||||
import regex as re
|
|
||||||
import logging
|
|
||||||
from dataclasses import fields
|
|
||||||
import torch.nn as nn
|
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
class EvalPrediction(NamedTuple):
|
|
||||||
"""
|
|
||||||
Evaluation output (always contains labels), to be used to compute metrics.
|
|
||||||
Parameters:
|
|
||||||
predictions (:obj:`np.ndarray`): Predictions of the model.
|
|
||||||
label_ids (:obj:`np.ndarray`): Targets to be matched.
|
|
||||||
data_info (:obj:`Dict[str, Any]`): Extra dataset information required
to perform the evaluation. data_info is a dictionary with the keys
train, eval and test, specifying the data_info for each split of the dataset.
|
|
||||||
"""
|
|
||||||
predictions: Union[np.ndarray, Tuple[np.ndarray]]
|
|
||||||
label_ids: np.ndarray
|
|
||||||
data_info: Dict[str, Any]
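# A hedged sketch of a compute_metrics callable consuming this tuple (the accuracy metric and
# the argmax step are illustrative only, not part of this repository):
#
#   def compute_metrics(p: EvalPrediction) -> Dict[str, float]:
#       preds = p.predictions.argmax(-1) if p.predictions.ndim > 1 else p.predictions
#       return {"accuracy": float((preds == p.label_ids).mean())}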
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_dir(output_dir):
|
|
||||||
"""
|
|
||||||
Checks whether the output_dir already exists and creates it if not.
|
|
||||||
Args:
|
|
||||||
output_dir: path to the output_dir
|
|
||||||
"""
|
|
||||||
if not os.path.exists(output_dir):
|
|
||||||
os.makedirs(output_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_checkpoint(output_dir):
|
|
||||||
if os.path.exists(os.path.join(output_dir, 'pytorch_model.bin')):
|
|
||||||
return output_dir
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def pad_punctuation(text):
|
|
||||||
"""Re-implementation of _pad_punctuation in t5. This function adds spaces
|
|
||||||
around punctuation. While this pads punctuation as expected, it has the
|
|
||||||
unexpected effected of padding certain unicode characters with accents, with
|
|
||||||
spaces as well. For instance: "François" becomes "Fran ç ois"""
|
|
||||||
# Pad everything except for: underscores (_), whitespace (\s),
|
|
||||||
# numbers (\p{N}), letters (\p{L}) and accent characters (\p{M}).
|
|
||||||
text = re.sub(r'([^_\s\p{N}\p{L}\p{M}])', r' \1 ', text)
|
|
||||||
# Collapse consecutive whitespace into one space.
|
|
||||||
text = re.sub(r'\s+', ' ', text)
|
|
||||||
return text
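# Illustrative example: pad_punctuation("Hello, world!") returns "Hello , world ! " -- the
# punctuation is space-padded and runs of whitespace are collapsed, but the result is not
# stripped, so a trailing space may remain.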
|
|
||||||
|
|
||||||
def save_json(filepath, dictionary):
|
|
||||||
with open(filepath, "w") as outfile:
|
|
||||||
json.dump(dictionary, outfile)
|
|
||||||
|
|
||||||
|
|
||||||
def read_json(filepath):
|
|
||||||
f = open(filepath,)
|
|
||||||
return json.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def save_training_config(config_file, output_dir):
|
|
||||||
json_data = read_json(config_file)
|
|
||||||
save_json(os.path.join(output_dir, "training_config.json"), json_data)
|
|
||||||
|
|
|
@ -0,0 +1 @@
from .utils import *
@ -1,10 +1,51 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
||||||
|
"""
|
||||||
|
model_name_or_path: str = field(
|
||||||
|
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
|
||||||
|
)
|
||||||
|
config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
tokenizer_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
||||||
|
)
|
||||||
|
cache_dir: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
|
||||||
|
)
|
||||||
|
use_fast_tokenizer: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
||||||
|
)
|
||||||
|
model_revision: str = field(
|
||||||
|
default="main",
|
||||||
|
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||||
|
)
|
||||||
|
use_auth_token: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||||
|
"with private models)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
num_classes:Optional[int]=field(
|
||||||
|
default=None, metadata={"help": "The number of classes, used to initialize classification models"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
from transformers import TrainingArguments as HfTrainingArguments
|
||||||
# run_seq2seq parameters.
|
# run_seq2seq parameters.
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments(Seq2SeqTrainingArguments):
|
class TrainingArguments(HfTrainingArguments):
|
||||||
print_num_parameters: Optional[bool] = field(default=False, metadata={"help": "If set, print the parameters of "
|
print_num_parameters: Optional[bool] = field(default=False, metadata={"help": "If set, print the parameters of "
|
||||||
"the model."})
|
"the model."})
|
||||||
do_test: Optional[bool] = field(default=False, metadata={"help": "If set, evaluates the test performance."})
|
do_test: Optional[bool] = field(default=False, metadata={"help": "If set, evaluates the test performance."})
|
||||||
|
@ -16,9 +57,31 @@ class TrainingArguments(Seq2SeqTrainingArguments):
|
||||||
"than 10K samples datasets), or by using 1K examples"
|
"than 10K samples datasets), or by using 1K examples"
|
||||||
"from training set as validation set (for larger"
|
"from training set as validation set (for larger"
|
||||||
" datasets)."})
|
" datasets)."})
|
||||||
compute_time: Optional[bool] = field(default=False, metadata={"help": "If set measures the time."})
|
compute_time: Optional[bool] = field(default=True, metadata={"help": "If set measures the time."})
|
||||||
compute_memory: Optional[bool] = field(default=False, metadata={"help": "if set, measures the memory"})
|
compute_memory: Optional[bool] = field(default=True, metadata={"help": "if set, measures the memory"})
|
||||||
is_seq2seq: Optional[bool] = field(default=True, metadata={"help": "whether the pipeline is a seq2seq one"})
|
is_seq2seq: Optional[bool] = field(default=True, metadata={"help": "whether the pipeline is a seq2seq one"})
|
||||||
|
sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
|
||||||
|
predict_with_generate: bool = field(
|
||||||
|
default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
|
||||||
|
)
|
||||||
|
generation_max_length: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||||
|
"to the `max_length` value of the model configuration."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
generation_num_beams: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
|
||||||
|
"to the `num_beams` value of the model configuration."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
remove_unused_columns: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,7 +111,7 @@ class DataTrainingArguments:
|
||||||
default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
|
default=None, metadata={"help": "The configuration name of the test dataset to use (via the datasets library)."}
|
||||||
)
|
)
|
||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=True, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -134,7 +197,6 @@ class DataTrainingArguments:
|
||||||
datasets_saved_path: Optional[str] = field(
|
datasets_saved_path: Optional[str] = field(
|
||||||
default=None, metadata={"help": "the path of the saved datasets"}
|
default=None, metadata={"help": "the path of the saved datasets"}
|
||||||
)
|
)
|
||||||
|
|
||||||
data_seed: Optional[int] = field(default=42, metadata={"help": "seed used to shuffle the data."})
|
data_seed: Optional[int] = field(default=42, metadata={"help": "seed used to shuffle the data."})
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,3 +209,30 @@ class DataTrainingArguments:
|
||||||
self.val_max_target_length = self.max_target_length
|
self.val_max_target_length = self.max_target_length
|
||||||
if self.test_max_target_length is None:
|
if self.test_max_target_length is None:
|
||||||
self.test_max_target_length = self.max_target_length
|
self.test_max_target_length = self.max_target_length
|
||||||
|
|
||||||
|
|
||||||
|
class RemainArgHfArgumentParser(HfArgumentParser):
|
||||||
|
def parse_json_file(self, json_file: str, return_remaining_args=True ):
|
||||||
|
"""
|
||||||
|
Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
|
||||||
|
dataclass types.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
data = json.loads(Path(json_file).read_text())
|
||||||
|
outputs = []
|
||||||
|
for dtype in self.dataclass_types:
|
||||||
|
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
|
||||||
|
inputs = {k: data.pop(k) for k in list(data.keys()) if k in keys}
|
||||||
|
obj = dtype(**inputs)
|
||||||
|
outputs.append(obj)
|
||||||
|
|
||||||
|
remain_args = argparse.ArgumentParser()
|
||||||
|
remain_args.__dict__.update(data)
|
||||||
|
if return_remaining_args:
|
||||||
|
return (*outputs, remain_args)
|
||||||
|
else:
|
||||||
|
return (*outputs,)
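# A hedged usage sketch (the config filename below is illustrative): keys in the JSON that match
# a dataclass field are routed to that dataclass, and whatever is left over is returned as the
# trailing extra-arguments object (used for the delta-tuning options in the training scripts above):
#
#   parser = RemainArgHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
#   model_args, data_args, training_args, delta_args = parser.parse_json_file("config.json")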
|
|
@ -1,15 +1,48 @@
|
||||||
import os
|
|
||||||
import regex as re
|
|
||||||
import logging
|
|
||||||
from dataclasses import fields
|
|
||||||
import torch.nn as nn
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
# class EvalPrediction(NamedTuple):
|
||||||
|
# """
|
||||||
|
# Evaluation output (always contains labels), to be used to compute metrics.
|
||||||
|
# Parameters:
|
||||||
|
# predictions (:obj:`np.ndarray`): Predictions of the model.
|
||||||
|
# label_ids (:obj:`np.ndarray`): Targets to be matched.
|
||||||
|
# data_info: (:obj:`Dict[str, Any]`): Extra dataset information, one requires
|
||||||
|
# to performs the evaluation. The data_info is a dictionary with keys from
|
||||||
|
# train, eval, test to specify the data_info for each split of the dataset.
|
||||||
|
# """
|
||||||
|
# predictions: Union[np.ndarray, Tuple[np.ndarray]]
|
||||||
|
# label_ids: np.ndarray
|
||||||
|
# data_info: Dict[str, Any]
|
||||||
|
|
||||||
|
def create_dir(output_dir):
|
||||||
|
"""
|
||||||
|
Checks whether the output_dir already exists and creates it if not.
|
||||||
|
Args:
|
||||||
|
output_dir: path to the output_dir
|
||||||
|
"""
|
||||||
|
if not os.path.exists(output_dir):
|
||||||
|
os.makedirs(output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_checkpoint(output_dir):
|
||||||
|
if os.path.exists(os.path.join(output_dir, 'pytorch_model.bin')):
|
||||||
|
return output_dir
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(filepath, dictionary):
|
||||||
|
with open(filepath, "w") as outfile:
|
||||||
|
json.dump(dictionary, outfile)
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(filepath):
|
||||||
|
f = open(filepath,)
|
||||||
|
return json.load(f)
|
|
@ -66,7 +66,7 @@ class SoftPromptLayer(nn.Module):
|
||||||
assert self.num_tokens>0
|
assert self.num_tokens>0
|
||||||
self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
|
self.instantiate(raw_embedding(torch.tensor([0])).shape[-1])
|
||||||
|
|
||||||
self.all_pseudo_tokens = {}
|
# self.all_pseudo_tokens = {}
|
||||||
|
|
||||||
def pre_forward(self, *args, **kwargs):
|
def pre_forward(self, *args, **kwargs):
|
||||||
# if attention_mask is passed as PLM's input, modify it here
|
# if attention_mask is passed as PLM's input, modify it here
|
||||||
|
@ -108,15 +108,15 @@ class SoftPromptLayer(nn.Module):
|
||||||
for expand_key in self.other_expand_ids:
|
for expand_key in self.other_expand_ids:
|
||||||
if expand_key in kwargs:
|
if expand_key in kwargs:
|
||||||
real_tokens = kwargs[expand_key]
|
real_tokens = kwargs[expand_key]
|
||||||
if expand_key in self.all_pseudo_tokens:
|
# if expand_key in self.all_pseudo_tokens:
|
||||||
pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
|
# pseudo_tokens = self.all_pseudo_tokens[expand_key].to(real_tokens.device)
|
||||||
else:
|
# else:
|
||||||
pseudo_tokens_value = self.other_expand_ids[expand_key]
|
pseudo_tokens_value = self.other_expand_ids[expand_key]
|
||||||
pseudo_tokens = torch.ones(
|
pseudo_tokens = torch.ones(
|
||||||
(*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
|
(*real_tokens.shape[:-1], inputs_embeds.shape[-2]-real_tokens.shape[-1]),
|
||||||
dtype = real_tokens.dtype,
|
dtype = real_tokens.dtype,
|
||||||
device=real_tokens.device) * pseudo_tokens_value
|
device=real_tokens.device) * pseudo_tokens_value
|
||||||
self.all_pseudo_tokens[expand_key] = pseudo_tokens
|
# self.all_pseudo_tokens[expand_key] = pseudo_tokens
|
||||||
real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
|
real_tokens.data = torch.cat([pseudo_tokens, real_tokens], dim=-1)
|
||||||
|
|
||||||
return args, kwargs
|
return args, kwargs
|
||||||
|
|