add CMMLU, update eval script

hiyouga 2023-09-23 21:10:17 +08:00
parent f8ff625d76
commit 4dd9b4d982
7 changed files with 507 additions and 61 deletions

README.md

@@ -14,7 +14,7 @@
 ## Changelog

-[23/09/23] We integrated MMLU and C-Eval benchmarks in this repo. See [this example](#evaluation-mmlu--c-eval) to evaluate your models.
+[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.

 [23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.

@@ -371,7 +371,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```

 ### API Demo

@@ -407,7 +408,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```

-### Evaluation and Predict (BLEU & ROUGE_CHINESE)
+### Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task mmlu \
+    --split test \
+    --lang en \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### Predict

 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

@@ -425,22 +441,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```

 > [!NOTE]
-> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit prediction.

-### Evaluation (MMLU & C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task mmlu \
-    --split test \
-    --lang en \
-    --n_shot 5 \
-    --batch_size 4
-```
-
 ## License
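The NOTE above pairs naturally with quantized prediction. A hypothetical 4-bit predict invocation combining those two flags; the dataset name and paths are illustrative, and the remaining flags follow the patterns used elsewhere in this README:

```bash
# Hedged sketch: 4-bit predict run applying the recommended flags.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_predict \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --quantization_bit 4 \
    --per_device_eval_batch_size 1 \
    --max_target_length 128 \
    --predict_with_generate
```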

README_zh.md

@@ -14,7 +14,7 @@
 ## Changelog

-[23/09/23] We integrated the MMLU and C-Eval benchmarks in this repo. See [this example](#模型评估mmlu-和-c-eval) for usage.
+[23/09/23] We integrated the MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#模型评估) for usage.

 [23/09/10] We supported **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. If you are using RTX4090, A100 or H100 GPUs, use the `--flash_attn` argument to enable FlashAttention-2 (experimental feature).

@@ -370,7 +370,8 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_export \
+    --fp16
 ```

 ### API Demo

@@ -406,7 +407,22 @@ python src/web_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```

-### Evaluation and Prediction (BLEU and Chinese ROUGE scores)
+### Model Evaluation
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
+    --model_name_or_path path_to_llama_model \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --template vanilla \
+    --task ceval \
+    --split validation \
+    --lang zh \
+    --n_shot 5 \
+    --batch_size 4
+```
+
+### Model Prediction

 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \

@@ -424,22 +440,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ```

 > [!NOTE]
-> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when evaluating quantized models.
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when predicting with quantized models.

-### Model Evaluation (MMLU & C-Eval)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
-    --model_name_or_path path_to_llama_model \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --template vanilla \
-    --task ceval \
-    --split validation \
-    --lang zh \
-    --n_shot 5 \
-    --batch_size 4
-```
-
 ## License

evaluation/ceval/ceval.py

@@ -92,14 +92,14 @@ task_list = [
 ]


-class CevalExamConfig(datasets.BuilderConfig):
+class CevalConfig(datasets.BuilderConfig):

     def __init__(self, **kwargs):
         super().__init__(version=datasets.Version("1.0.0"), **kwargs)


-class CevalExam(datasets.GeneratorBasedBuilder):
+class Ceval(datasets.GeneratorBasedBuilder):
     BUILDER_CONFIGS = [
-        CevalExamConfig(
+        CevalConfig(
             name=task_name,
         )
         for task_name in task_list

evaluation/cmmlu/cmmlu.py Normal file

@@ -0,0 +1,163 @@
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import datasets
import pandas as pd
_CITATION = """\
@article{li2023cmmlu,
title={CMMLU: Measuring massive multitask language understanding in Chinese},
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
journal={arXiv preprint arXiv:2306.09212},
year={2023}
}
"""
_DESCRIPTION = """\
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
"""
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
_URL = "cmmlu.zip"
task_list = [
'agronomy',
'anatomy',
'ancient_chinese',
'arts',
'astronomy',
'business_ethics',
'chinese_civil_service_exam',
'chinese_driving_rule',
'chinese_food_culture',
'chinese_foreign_policy',
'chinese_history',
'chinese_literature',
'chinese_teacher_qualification',
'clinical_knowledge',
'college_actuarial_science',
'college_education',
'college_engineering_hydrology',
'college_law',
'college_mathematics',
'college_medical_statistics',
'college_medicine',
'computer_science',
'computer_security',
'conceptual_physics',
'construction_project_management',
'economics',
'education',
'electrical_engineering',
'elementary_chinese',
'elementary_commonsense',
'elementary_information_and_technology',
'elementary_mathematics',
'ethnology',
'food_science',
'genetics',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_geography',
'high_school_mathematics',
'high_school_physics',
'high_school_politics',
'human_sexuality',
'international_law',
'journalism',
'jurisprudence',
'legal_and_moral_basis',
'logical',
'machine_learning',
'management',
'marketing',
'marxist_theory',
'modern_chinese',
'nutrition',
'philosophy',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_study',
'sociology',
'sports_science',
'traditional_chinese_medicine',
'virology',
'world_history',
'world_religions',
]
class CMMLUConfig(datasets.BuilderConfig):
def __init__(self, **kwargs):
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
class CMMLU(datasets.GeneratorBasedBuilder):
BUILDER_CONFIGS = [
CMMLUConfig(
name=task_name,
)
for task_name in task_list
]
def _info(self):
features = datasets.Features(
{
"question": datasets.Value("string"),
"A": datasets.Value("string"),
"B": datasets.Value("string"),
"C": datasets.Value("string"),
"D": datasets.Value("string"),
"answer": datasets.Value("string"),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_dir = dl_manager.download_and_extract(_URL)
task_name = self.config.name
return [
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
},
),
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
},
),
]
def _generate_examples(self, filepath):
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
for i, instance in enumerate(df.to_dict(orient="records")):
yield i, instance
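As a quick sanity check, the new builder can be loaded directly from the script directory with `datasets.load_dataset`, the same way `src/evaluate.py` loads tasks. A minimal sketch; the subject name is chosen arbitrarily from `task_list`:

```python
from datasets import load_dataset

# Load one CMMLU subject from the local dataset script; the builder
# downloads/extracts cmmlu.zip and reads test/<subject>.csv and dev/<subject>.csv.
dataset = load_dataset("evaluation/cmmlu", "agronomy")

# "test" holds the scored questions; "train" (built from the dev split)
# supplies the few-shot support examples.
print(dataset["test"][0])
```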

evaluation/cmmlu/cmmlu.zip Normal file

Binary file not shown.

evaluation/cmmlu/mapping.json Normal file

@@ -0,0 +1,270 @@
{
"agronomy": {
"name": "农学",
"category": "Other"
},
"anatomy": {
"name": "解剖学",
"category": "STEM"
},
"ancient_chinese": {
"name": "古汉语",
"category": "Social Sciences"
},
"arts": {
"name": "艺术学",
"category": "Humanities"
},
"astronomy": {
"name": "天文学",
"category": "STEM"
},
"business_ethics": {
"name": "商业伦理",
"category": "Social Sciences"
},
"chinese_civil_service_exam": {
"name": "中国公务员考试",
"category": "Social Sciences"
},
"chinese_driving_rule": {
"name": "中国驾驶规则",
"category": "Other"
},
"chinese_food_culture": {
"name": "中国饮食文化",
"category": "Social Sciences"
},
"chinese_foreign_policy": {
"name": "中国外交政策",
"category": "Social Sciences"
},
"chinese_history": {
"name": "中国历史",
"category": "Humanities"
},
"chinese_literature": {
"name": "中国文学",
"category": "Humanities"
},
"chinese_teacher_qualification": {
"name": "中国教师资格",
"category": "Social Sciences"
},
"college_actuarial_science": {
"name": "大学精算学",
"category": "STEM"
},
"college_education": {
"name": "大学教育学",
"category": "Social Sciences"
},
"college_engineering_hydrology": {
"name": "大学工程水文学",
"category": "STEM"
},
"college_law": {
"name": "大学法律",
"category": "Humanities"
},
"college_mathematics": {
"name": "大学数学",
"category": "STEM"
},
"college_medical_statistics": {
"name": "大学医学统计",
"category": "STEM"
},
"clinical_knowledge": {
"name": "临床知识",
"category": "Other"
},
"college_medicine": {
"name": "大学医学",
"category": "Other"
},
"computer_science": {
"name": "计算机科学",
"category": "STEM"
},
"computer_security": {
"name": "计算机安全",
"category": "Other"
},
"conceptual_physics": {
"name": "概念物理学",
"category": "STEM"
},
"construction_project_management": {
"name": "建设工程管理",
"category": "Other"
},
"economics": {
"name": "经济学",
"category": "Social Sciences"
},
"education": {
"name": "教育学",
"category": "Social Sciences"
},
"elementary_chinese": {
"name": "小学语文",
"category": "Social Sciences"
},
"elementary_commonsense": {
"name": "小学常识",
"category": "Other"
},
"elementary_information_and_technology": {
"name": "小学信息技术",
"category": "Other"
},
"electrical_engineering": {
"name": "电气工程",
"category": "STEM"
},
"elementary_mathematics": {
"name": "初等数学",
"category": "STEM"
},
"ethnology": {
"name": "民族学",
"category": "Social Sciences"
},
"food_science": {
"name": "食品科学",
"category": "Other"
},
"genetics": {
"name": "遗传学",
"category": "STEM"
},
"global_facts": {
"name": "全球事实",
"category": "Humanities"
},
"high_school_biology": {
"name": "高中生物",
"category": "STEM"
},
"high_school_chemistry": {
"name": "高中化学",
"category": "STEM"
},
"high_school_geography": {
"name": "高中地理",
"category": "Social Sciences"
},
"high_school_mathematics": {
"name": "高中数学",
"category": "STEM"
},
"high_school_physics": {
"name": "高中物理学",
"category": "STEM"
},
"high_school_politics": {
"name": "高中政治",
"category": "Social Sciences"
},
"human_sexuality": {
"name": "人类性行为",
"category": "Other"
},
"international_law": {
"name": "国际法学",
"category": "Humanities"
},
"journalism": {
"name": "新闻学",
"category": "Social Sciences"
},
"jurisprudence": {
"name": "法理学",
"category": "Humanities"
},
"legal_and_moral_basis": {
"name": "法律与道德基础",
"category": "Other"
},
"logical": {
"name": "逻辑学",
"category": "Humanities"
},
"machine_learning": {
"name": "机器学习",
"category": "STEM"
},
"management": {
"name": "管理学",
"category": "Social Sciences"
},
"marketing": {
"name": "市场营销",
"category": "Social Sciences"
},
"marxist_theory": {
"name": "马克思主义理论",
"category": "Humanities"
},
"modern_chinese": {
"name": "现代汉语",
"category": "Social Sciences"
},
"nutrition": {
"name": "营养学",
"category": "Other"
},
"philosophy": {
"name": "哲学",
"category": "Humanities"
},
"professional_accounting": {
"name": "专业会计",
"category": "Social Sciences"
},
"professional_law": {
"name": "专业法学",
"category": "Humanities"
},
"professional_medicine": {
"name": "专业医学",
"category": "Other"
},
"professional_psychology": {
"name": "专业心理学",
"category": "Social Sciences"
},
"public_relations": {
"name": "公共关系",
"category": "Social Sciences"
},
"security_study": {
"name": "安全研究",
"category": "Social Sciences"
},
"sociology": {
"name": "社会学",
"category": "Social Sciences"
},
"sports_science": {
"name": "体育学",
"category": "Other"
},
"traditional_chinese_medicine": {
"name": "中医中药",
"category": "Other"
},
"virology": {
"name": "病毒学",
"category": "STEM"
},
"world_history": {
"name": "世界历史",
"category": "Humanities"
},
"world_religions": {
"name": "世界宗教",
"category": "Humanities"
}
}
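For reference, a minimal sketch of how this mapping can be consumed to bucket subjects by category, mirroring the per-category accounting in `src/evaluate.py`; the relative path assumes the repo root as working directory:

```python
import json
from collections import defaultdict

# Read the CMMLU subject-to-category mapping added in this commit.
with open("evaluation/cmmlu/mapping.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)

# Bucket subjects by category, as the evaluation script does when
# accumulating per-category accuracies.
by_category = defaultdict(list)
for subject, info in mapping.items():
    by_category[info["category"]].append(subject)

for category, subjects in sorted(by_category.items()):
    print("{:>16}: {} subjects".format(category, len(subjects)))
```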

src/evaluate.py

@ -1,16 +1,15 @@
# coding=utf-8 # coding=utf-8
# Evaluates the performance of pre-trained models. # Evaluates the performance of pre-trained models.
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla # Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 # --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py # Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
import os import os
import fire import fire
import json import json
import torch import torch
import random
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm, trange
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from datasets import load_dataset from datasets import load_dataset
from dataclasses import dataclass from dataclasses import dataclass
@ -30,6 +29,7 @@ class EvalTemplate:
system: str system: str
choice: str choice: str
answer: str answer: str
prefix: str
def parse_example( def parse_example(
self, self,
@ -49,7 +49,6 @@ class EvalTemplate:
history = [self.parse_example(support_set[k]) for k in range(len(support_set))] history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
if len(history): if len(history):
random.shuffle(history)
temp = history.pop(0) temp = history.pop(0)
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1])) history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
else: else:
@ -65,12 +64,14 @@ eval_templates = {
"en": EvalTemplate( "en": EvalTemplate(
system="The following are multiple choice questions (with answers) about {subject}.\n\n", system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\nAnswer: " answer="\nAnswer: ",
prefix=" "
), ),
"zh": EvalTemplate( "zh": EvalTemplate(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\n答案:" answer="\n答案:",
prefix="\n"
) )
} }
@ -79,9 +80,8 @@ eval_templates = {
def batch_inference( def batch_inference(
chat_model: ChatModel, chat_model: ChatModel,
batch_input: Dict[str, torch.Tensor], batch_input: Dict[str, torch.Tensor],
lang: Literal["zh", "en"] prefix_char: str
) -> List[str]: ) -> List[str]:
prefix_char = "\n" if lang == "zh" else " "
logits = chat_model.model(**batch_input).logits logits = chat_model.model(**batch_input).logits
probs = torch.nn.functional.softmax( probs = torch.nn.functional.softmax(
torch.stack( torch.stack(
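The hunk above is truncated mid-expression by the diff viewer. For orientation, a simplified standalone sketch of the scoring idea it implements; a generic model and tokenizer stand in for the repo's `ChatModel` wrapper, and `score_choices` is a hypothetical name:

```python
from typing import List

import torch

def score_choices(model, tokenizer, batch_input, prefix_char: str) -> List[str]:
    # Forward pass; with left padding, position -1 is the end of every prompt.
    logits = model(**batch_input).logits[:, -1, :]

    # Token id of each candidate answer, e.g. " A" in English or "\nA" in
    # Chinese; the prefix matches the character that follows "Answer:" / "答案:".
    choice_ids = [
        tokenizer.encode(prefix_char + ch, add_special_tokens=False)[-1]
        for ch in ["A", "B", "C", "D"]
    ]

    # Restrict the comparison to the four choice logits and take the argmax.
    probs = torch.nn.functional.softmax(logits[:, choice_ids], dim=-1)
    return [chr(ord("A") + i.item()) for i in torch.argmax(probs, dim=-1)]
```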
@@ -108,7 +108,8 @@ def evaluate(
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
-    batch_size: Optional[int] = 4
+    batch_size: Optional[int] = 4,
+    save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
         categorys = json.load(f)

@@ -119,25 +120,25 @@ def evaluate(
         checkpoint_dir=checkpoint_dir,
         template=template
     ))
+    chat_model.tokenizer.padding_side = "left"
     eval_template = eval_templates[lang]
-    assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."

     category_corrects: Dict[str, np.ndarray] = {
-        "STEM": np.array([], dtype="bool"),
-        "Social Sciences": np.array([], dtype="bool"),
-        "Humanities": np.array([], dtype="bool"),
-        "Other": np.array([], dtype="bool")
+        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
     }
     overall_corrects = np.array([], dtype="bool")

-    pbar = tqdm(categorys.keys())
+    pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
+    results = {}
     for subject in pbar:
         pbar.set_postfix_str(categorys[subject]["name"])
         inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
         for i in range(len(dataset[split])):
+            support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
             query, resp, history = eval_template.format_example(
                 target_data=dataset[split][i],
-                support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
+                support_set=support_set,
                 subject_name=categorys[subject]["name"],
                 use_history=chat_model.template.use_history
             )

@@ -154,23 +155,33 @@ def evaluate(
             labels.append(resp)

         outputs = []
-        for i in range(0, len(inputs), batch_size):
+        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
             batch_input = chat_model.tokenizer.pad(
                 inputs[i : i + batch_size],
                 return_attention_mask=True,
                 return_tensors="pt"
             ).to(chat_model.model.device)
-            preds = batch_inference(chat_model, batch_input, lang)
+            preds = batch_inference(chat_model, batch_input, eval_template.prefix)
             outputs += preds

         corrects = (np.array(outputs) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
         overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
+        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}

-    print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
+    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
     for category_name, category_correct in category_corrects.items():
-        print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))
+        if len(category_correct):
+            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+
+    print(score_info)
+
+    if save_name is not None:
+        with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
+            json.dump(results, f, indent=2)
+        with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
+            f.write(score_info)


 if __name__ == "__main__":
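Putting the pieces together, a CMMLU run exercising the new `--save_name` option might look as follows. Paths are placeholders; per the hunk above, `cmmlu_eval.json` receives the per-subject predictions and `cmmlu_eval.log` the score summary:

```bash
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
    --model_name_or_path path_to_llama_model \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
    --template vanilla \
    --task cmmlu \
    --split test \
    --lang zh \
    --n_shot 5 \
    --batch_size 4 \
    --save_name cmmlu_eval
```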