forked from p04798526/LLaMA-Factory-Mirror
add CMMLU, update eval script
This commit is contained in:
parent
f8ff625d76
commit
4dd9b4d982
39
README.md
39
README.md
|
@ -14,7 +14,7 @@
|
|||
|
||||
## Changelog
|
||||
|
||||
[23/09/23] We integrated MMLU and C-Eval benchmarks in this repo. See [this example](#evaluation-mmlu--c-eval) to evaluate your models.
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/09/10] We supported using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
|
@ -371,7 +371,8 @@ python src/export_model.py \
|
|||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
--output_dir path_to_export \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### API Demo
|
||||
|
@ -407,7 +408,22 @@ python src/web_demo.py \
|
|||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Evaluation and Predict (BLEU & ROUGE_CHINESE)
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
### Predict
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -425,22 +441,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
```
|
||||
|
||||
> [!NOTE]
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
|
||||
|
||||
### Evaluation (MMLU & C-Eval)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
|
||||
|
||||
## License
|
||||
|
||||
|
|
39
README_zh.md
39
README_zh.md
|
@ -14,7 +14,7 @@
|
|||
|
||||
## 更新日志
|
||||
|
||||
[23/09/23] 我们在项目中集成了 MMLU 和 C-Eval 评估集。使用方法请参阅[此示例](#模型评估mmlu-和-c-eval)。
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/09/10] 我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2(实验性功能)。
|
||||
|
||||
|
@ -370,7 +370,8 @@ python src/export_model.py \
|
|||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
--output_dir path_to_export \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### API 服务
|
||||
|
@ -406,7 +407,22 @@ python src/web_demo.py \
|
|||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 指标评估与模型预测(BLEU 分数和汉语 ROUGE 分数)
|
||||
### 模型评估
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
### 模型预测
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
|
@ -424,22 +440,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
|||
```
|
||||
|
||||
> [!NOTE]
|
||||
> 我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
### 模型评估(MMLU 和 C-Eval)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
## 协议
|
||||
|
||||
|
|
|
@ -92,14 +92,14 @@ task_list = [
|
|||
]
|
||||
|
||||
|
||||
class CevalExamConfig(datasets.BuilderConfig):
|
||||
class CevalConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
|
||||
|
||||
|
||||
class CevalExam(datasets.GeneratorBasedBuilder):
|
||||
class Ceval(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CevalExamConfig(
|
||||
CevalConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{li2023cmmlu,
|
||||
title={CMMLU: Measuring massive multitask language understanding in Chinese},
|
||||
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
|
||||
journal={arXiv preprint arXiv:2306.09212},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
|
||||
|
||||
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
|
||||
|
||||
_URL = "cmmlu.zip"
|
||||
|
||||
task_list = [
|
||||
'agronomy',
|
||||
'anatomy',
|
||||
'ancient_chinese',
|
||||
'arts',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'chinese_civil_service_exam',
|
||||
'chinese_driving_rule',
|
||||
'chinese_food_culture',
|
||||
'chinese_foreign_policy',
|
||||
'chinese_history',
|
||||
'chinese_literature',
|
||||
'chinese_teacher_qualification',
|
||||
'clinical_knowledge',
|
||||
'college_actuarial_science',
|
||||
'college_education',
|
||||
'college_engineering_hydrology',
|
||||
'college_law',
|
||||
'college_mathematics',
|
||||
'college_medical_statistics',
|
||||
'college_medicine',
|
||||
'computer_science',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'construction_project_management',
|
||||
'economics',
|
||||
'education',
|
||||
'electrical_engineering',
|
||||
'elementary_chinese',
|
||||
'elementary_commonsense',
|
||||
'elementary_information_and_technology',
|
||||
'elementary_mathematics',
|
||||
'ethnology',
|
||||
'food_science',
|
||||
'genetics',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_geography',
|
||||
'high_school_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_politics',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'journalism',
|
||||
'jurisprudence',
|
||||
'legal_and_moral_basis',
|
||||
'logical',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'marxist_theory',
|
||||
'modern_chinese',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_study',
|
||||
'sociology',
|
||||
'sports_science',
|
||||
'traditional_chinese_medicine',
|
||||
'virology',
|
||||
'world_history',
|
||||
'world_religions',
|
||||
]
|
||||
|
||||
|
||||
class CMMLUConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
|
||||
|
||||
|
||||
class CMMLU(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CMMLUConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
yield i, instance
|
Binary file not shown.
|
@ -0,0 +1,270 @@
|
|||
{
|
||||
"agronomy": {
|
||||
"name": "农学",
|
||||
"category": "Other"
|
||||
},
|
||||
"anatomy": {
|
||||
"name": "解剖学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"ancient_chinese": {
|
||||
"name": "古汉语",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"arts": {
|
||||
"name": "艺术学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"astronomy": {
|
||||
"name": "天文学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"business_ethics": {
|
||||
"name": "商业伦理",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"chinese_civil_service_exam": {
|
||||
"name": "中国公务员考试",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"chinese_driving_rule": {
|
||||
"name": "中国驾驶规则",
|
||||
"category": "Other"
|
||||
},
|
||||
"chinese_food_culture": {
|
||||
"name": "中国饮食文化",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"chinese_foreign_policy": {
|
||||
"name": "中国外交政策",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"chinese_history": {
|
||||
"name": "中国历史",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"chinese_literature": {
|
||||
"name": "中国文学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"chinese_teacher_qualification": {
|
||||
"name": "中国教师资格",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"college_actuarial_science": {
|
||||
"name": "大学精算学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"college_education": {
|
||||
"name": "大学教育学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"college_engineering_hydrology": {
|
||||
"name": "大学工程水文学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"college_law": {
|
||||
"name": "大学法律",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"college_mathematics": {
|
||||
"name": "大学数学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"college_medical_statistics": {
|
||||
"name": "大学医学统计",
|
||||
"category": "STEM"
|
||||
},
|
||||
"clinical_knowledge": {
|
||||
"name": "临床知识",
|
||||
"category": "Other"
|
||||
},
|
||||
"college_medicine": {
|
||||
"name": "大学医学",
|
||||
"category": "Other"
|
||||
},
|
||||
"computer_science": {
|
||||
"name": "计算机科学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"computer_security": {
|
||||
"name": "计算机安全",
|
||||
"category": "Other"
|
||||
},
|
||||
"conceptual_physics": {
|
||||
"name": "概念物理学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"construction_project_management": {
|
||||
"name": "建设工程管理",
|
||||
"category": "Other"
|
||||
},
|
||||
"economics": {
|
||||
"name": "经济学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"education": {
|
||||
"name": "教育学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"elementary_chinese": {
|
||||
"name": "小学语文",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"elementary_commonsense": {
|
||||
"name": "小学常识",
|
||||
"category": "Other"
|
||||
},
|
||||
"elementary_information_and_technology": {
|
||||
"name": "小学信息技术",
|
||||
"category": "Other"
|
||||
},
|
||||
"electrical_engineering": {
|
||||
"name": "电气工程",
|
||||
"category": "STEM"
|
||||
},
|
||||
"elementary_mathematics": {
|
||||
"name": "初等数学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"ethnology": {
|
||||
"name": "民族学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"food_science": {
|
||||
"name": "食品科学",
|
||||
"category": "Other"
|
||||
},
|
||||
"genetics": {
|
||||
"name": "遗传学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"global_facts": {
|
||||
"name": "全球事实",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"high_school_biology": {
|
||||
"name": "高中生物",
|
||||
"category": "STEM"
|
||||
},
|
||||
"high_school_chemistry": {
|
||||
"name": "高中化学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"high_school_geography": {
|
||||
"name": "高中地理",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"high_school_mathematics": {
|
||||
"name": "高中数学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"high_school_physics": {
|
||||
"name": "高中物理学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"high_school_politics": {
|
||||
"name": "高中政治",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"human_sexuality": {
|
||||
"name": "人类性行为",
|
||||
"category": "Other"
|
||||
},
|
||||
"international_law": {
|
||||
"name": "国际法学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"journalism": {
|
||||
"name": "新闻学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"jurisprudence": {
|
||||
"name": "法理学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"legal_and_moral_basis": {
|
||||
"name": "法律与道德基础",
|
||||
"category": "Other"
|
||||
},
|
||||
"logical": {
|
||||
"name": "逻辑学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"machine_learning": {
|
||||
"name": "机器学习",
|
||||
"category": "STEM"
|
||||
},
|
||||
"management": {
|
||||
"name": "管理学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"marketing": {
|
||||
"name": "市场营销",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"marxist_theory": {
|
||||
"name": "马克思主义理论",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"modern_chinese": {
|
||||
"name": "现代汉语",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"nutrition": {
|
||||
"name": "营养学",
|
||||
"category": "Other"
|
||||
},
|
||||
"philosophy": {
|
||||
"name": "哲学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"professional_accounting": {
|
||||
"name": "专业会计",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"professional_law": {
|
||||
"name": "专业法学",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"professional_medicine": {
|
||||
"name": "专业医学",
|
||||
"category": "Other"
|
||||
},
|
||||
"professional_psychology": {
|
||||
"name": "专业心理学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"public_relations": {
|
||||
"name": "公共关系",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"security_study": {
|
||||
"name": "安全研究",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"sociology": {
|
||||
"name": "社会学",
|
||||
"category": "Social Sciences"
|
||||
},
|
||||
"sports_science": {
|
||||
"name": "体育学",
|
||||
"category": "Other"
|
||||
},
|
||||
"traditional_chinese_medicine": {
|
||||
"name": "中医中药",
|
||||
"category": "Other"
|
||||
},
|
||||
"virology": {
|
||||
"name": "病毒学",
|
||||
"category": "STEM"
|
||||
},
|
||||
"world_history": {
|
||||
"name": "世界历史",
|
||||
"category": "Humanities"
|
||||
},
|
||||
"world_religions": {
|
||||
"name": "世界宗教",
|
||||
"category": "Humanities"
|
||||
}
|
||||
}
|
|
@ -1,16 +1,15 @@
|
|||
# coding=utf-8
|
||||
# Evaluates the performance of pre-trained models.
|
||||
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
|
||||
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4
|
||||
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
|
||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from tqdm import tqdm, trange
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass
|
||||
|
@ -30,6 +29,7 @@ class EvalTemplate:
|
|||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(
|
||||
self,
|
||||
|
@ -49,7 +49,6 @@ class EvalTemplate:
|
|||
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
||||
|
||||
if len(history):
|
||||
random.shuffle(history)
|
||||
temp = history.pop(0)
|
||||
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
||||
else:
|
||||
|
@ -65,12 +64,14 @@ eval_templates = {
|
|||
"en": EvalTemplate(
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: "
|
||||
answer="\nAnswer: ",
|
||||
prefix=" "
|
||||
),
|
||||
"zh": EvalTemplate(
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:"
|
||||
answer="\n答案:",
|
||||
prefix="\n"
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -79,9 +80,8 @@ eval_templates = {
|
|||
def batch_inference(
|
||||
chat_model: ChatModel,
|
||||
batch_input: Dict[str, torch.Tensor],
|
||||
lang: Literal["zh", "en"]
|
||||
prefix_char: str
|
||||
) -> List[str]:
|
||||
prefix_char = "\n" if lang == "zh" else " "
|
||||
logits = chat_model.model(**batch_input).logits
|
||||
probs = torch.nn.functional.softmax(
|
||||
torch.stack(
|
||||
|
@ -108,7 +108,8 @@ def evaluate(
|
|||
split: Optional[Literal["validation", "test"]] = "validation",
|
||||
lang: Optional[Literal["zh", "en"]] = "zh",
|
||||
n_shot: Optional[int] = 5,
|
||||
batch_size: Optional[int] = 4
|
||||
batch_size: Optional[int] = 4,
|
||||
save_name: Optional[str] = None
|
||||
):
|
||||
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
|
||||
categorys = json.load(f)
|
||||
|
@ -119,25 +120,25 @@ def evaluate(
|
|||
checkpoint_dir=checkpoint_dir,
|
||||
template=template
|
||||
))
|
||||
chat_model.tokenizer.padding_side = "left"
|
||||
eval_template = eval_templates[lang]
|
||||
assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
|
||||
|
||||
category_corrects: Dict[str, np.ndarray] = {
|
||||
"STEM": np.array([], dtype="bool"),
|
||||
"Social Sciences": np.array([], dtype="bool"),
|
||||
"Humanities": np.array([], dtype="bool"),
|
||||
"Other": np.array([], dtype="bool")
|
||||
subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
|
||||
}
|
||||
overall_corrects = np.array([], dtype="bool")
|
||||
|
||||
pbar = tqdm(categorys.keys())
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, labels = [], []
|
||||
dataset = load_dataset(os.path.join(dataset_dir, task), subject)
|
||||
for i in range(len(dataset[split])):
|
||||
support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
|
||||
query, resp, history = eval_template.format_example(
|
||||
target_data=dataset[split][i],
|
||||
support_set=dataset["train"].select(range(min(n_shot, len(dataset["train"])))),
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
use_history=chat_model.template.use_history
|
||||
)
|
||||
|
@ -154,23 +155,33 @@ def evaluate(
|
|||
labels.append(resp)
|
||||
|
||||
outputs = []
|
||||
for i in range(0, len(inputs), batch_size):
|
||||
for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
|
||||
batch_input = chat_model.tokenizer.pad(
|
||||
inputs[i : i + batch_size],
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt"
|
||||
).to(chat_model.model.device)
|
||||
preds = batch_inference(chat_model, batch_input, lang)
|
||||
preds = batch_inference(chat_model, batch_input, eval_template.prefix)
|
||||
outputs += preds
|
||||
|
||||
corrects = (np.array(outputs) == np.array(labels))
|
||||
category_name = categorys[subject]["category"]
|
||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||
overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
|
||||
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
|
||||
|
||||
print("Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects)))
|
||||
score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
|
||||
for category_name, category_correct in category_corrects.items():
|
||||
print(" {} - {:.2f}".format(category_name, 100 * np.mean(category_correct)))
|
||||
if len(category_correct):
|
||||
score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
|
||||
print(score_info)
|
||||
if save_name is not None:
|
||||
with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue