# MIT License

# Copyright (c) 2023 THU-KEG & Zhipu AI

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import argparse
import re
import string
from collections import Counter

import jieba
import numpy as np
from fuzzywuzzy import fuzz
from rouge import Rouge


def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

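# Illustrative check (worked out by hand from the helpers above):
#   normalize_answer("The  Quick, Brown Fox!") -> "quick brown fox"
# (lowercased, punctuation and the article "the" dropped, whitespace collapsed)
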
def normalize_zh_answer(s):
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        # Full-width (CJK) punctuation; kept alongside ASCII punctuation below.
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))

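# Illustrative check: full-width punctuation is stripped and all whitespace removed,
#   normalize_zh_answer("你好，世界！") -> "你好世界"
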
def count_score(prediction, ground_truth, **kwargs):
    # Fraction of the numbers mentioned in the prediction that equal the ground-truth count.
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

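# Illustrative check:
#   count_score("I count 7 paragraphs, maybe 8", "7") -> 0.5
# (two numbers extracted, one of them matches the ground truth)
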
def retrieval_score(prediction, ground_truth, **kwargs):
    # The ground truth is a string such as "Paragraph 3"; the score is the fraction of
    # numbers in the prediction that equal that paragraph id.
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)


def retrieval_zh_score(prediction, ground_truth, **kwargs):
    # Chinese variant of retrieval_score; the ground truth is a string such as "段落3".
    pattern = r'段落(\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)


def code_sim_score(prediction, ground_truth, **kwargs):
    # Take the first line of the prediction that does not look like a code fence or a
    # comment, and fuzzy-match it against the reference line.
    all_lines = prediction.lstrip('\n').split('\n')
    prediction = ""
    for line in all_lines:
        if ('`' not in line) and ('#' not in line) and ('//' not in line):
            prediction = line
            break
    return fuzz.ratio(prediction, ground_truth) / 100

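# Illustrative check:
#   code_sim_score("```\nreturn a + b\n```", "return a + b") -> 1.0
# (the fence lines are skipped; the remaining line matches the reference exactly)
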
def classification_score(prediction, ground_truth, **kwargs):
    em_match_list = []
    all_classes = kwargs.get("all_classes", [])
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    # Drop labels that are mere substrings of the ground truth (e.g. "news" vs "world news");
    # iterate over a copy so removing items does not skip elements.
    for match_term in list(em_match_list):
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if ground_truth in em_match_list and len(em_match_list) != 0:
        score = 1.0 / len(em_match_list)
    else:
        score = 0.0
    return score

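# Illustrative check (the class names here are made up for the example):
#   classification_score("This is about sports and business.", "sports",
#                        all_classes=["sports", "world news", "business"]) -> 0.5
# (two labels matched; the ground-truth label shares the credit equally)
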
def rouge_score(prediction, ground_truth, **kwargs):
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:
        # Rouge raises on empty or otherwise unusable hypotheses; treat those as zero.
        return 0.0
    return scores["rouge-l"]["f"]


def rouge_zh_score(prediction, ground_truth, **kwargs):
    # Segment Chinese text with jieba so ROUGE can operate on space-separated tokens.
    prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
    ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
    score = rouge_score(prediction, ground_truth)
    return score


def f1_score(prediction, ground_truth, **kwargs):
    # Token-level F1: `prediction` and `ground_truth` are lists of tokens.
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    if len(prediction) != 0 and len(ground_truth) != 0:
        precision = 1.0 * num_same / len(prediction)
        recall = 1.0 * num_same / len(ground_truth)
        f1 = (2 * precision * recall) / (precision + recall)
    else:
        f1 = 0.0
    return f1


def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)

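# Illustrative check (worked out by hand from the definitions above):
#   qa_f1_score("The cat sat on the mat", "a cat on a mat")
#   normalized tokens: ["cat", "sat", "on", "mat"] vs ["cat", "on", "mat"]
#   precision = 3/4, recall = 3/3, F1 = 2 * 0.75 * 1.0 / 1.75 ≈ 0.857
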
def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)


dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

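# Illustrative dispatch through the mapping above:
#   dataset2metric["hotpotqa"]("Paris is the capital", "Paris", all_classes=None) -> 0.5
# (token-level F1 via qa_f1_score: precision 1/3, recall 1, F1 = 0.5)
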
def scorer_e(dataset, predictions, answers, lengths, all_classes):
    # LongBench-E reports scores per input-length bucket; note that np.mean of an
    # empty bucket is NaN (with a RuntimeWarning).
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            # For few-shot tasks, only the first line of the prediction
            # (leading blank lines stripped) is scored.
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            # Take the best score over all reference answers.
            score = max(score, dataset2metric.get(dataset, qa_f1_score)(prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores.keys():
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores


def scorer(dataset, predictions, answers, all_classes):
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            # For few-shot tasks, only the first line of the prediction
            # (leading blank lines stripped) is scored.
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            # Take the best score over all reference answers.
            score = max(score, dataset2metric.get(dataset, qa_f1_score)(prediction, ground_truth, all_classes=all_classes))
        total_score += score
    if len(predictions) == 0:
        return 0.0
    else:
        return round(100 * total_score / len(predictions), 2)
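

# A minimal usage sketch, not part of the original evaluation driver. It assumes
# predictions are stored as JSON Lines with "pred", "answers", "length" and
# "all_classes" fields under a hypothetical layout "pred/{model}/{dataset}.jsonl";
# adjust both assumptions to match your own setup.
if __name__ == "__main__":
    import json
    import os

    args = parse_args()
    pred_dir = os.path.join("pred", args.model or "")
    for filename in sorted(os.listdir(pred_dir)):
        if not filename.endswith(".jsonl"):
            continue
        dataset = filename[:-len(".jsonl")]
        predictions, answers, lengths, all_classes = [], [], [], None
        with open(os.path.join(pred_dir, filename), encoding="utf-8") as f:
            for line in f:
                record = json.loads(line)
                predictions.append(record["pred"])
                answers.append(record["answers"])
                lengths.append(record.get("length", 0))
                all_classes = record.get("all_classes")
        if args.e:
            result = scorer_e(dataset, predictions, answers, lengths, all_classes)
        else:
            result = scorer(dataset, predictions, answers, all_classes)
        print(dataset, result)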