# coding=utf-8
# Evaluates fine-tuned models automatically via an OpenAI-compatible API server.
# Usage: python evaluate.py --evalset ceval/ceval-exam:law --split dev --api_base http://localhost:8000/v1 --task_type choice


import os
import json
from typing import Literal, Optional

import fire
import openai
from tqdm import tqdm
from datasets import load_dataset


# Maps a local dataset file's extension to the loader name that
# `datasets.load_dataset` expects ("jsonl" files use the "json" loader).
EXT2TYPE = {
    "csv": "csv",
    "json": "json",
    "jsonl": "json"
}


def format_example_choice(examples):
    """Batched `datasets.map` callback for multiple-choice questions.

    Expects columns: id, question, A, B, C, D, answer (a letter "A"-"D").
    Adds two columns: "query" (the rendered prompt) and "label" (the answer letter).
    """
    model_inputs = {"query": [], "label": []}
    task_template = "请从ABCD四个选项中选出正确的选项,仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:"
    for i in range(len(examples["id"])):
        query = task_template.format(
            question=examples["question"][i],
            A=examples["A"][i],
            B=examples["B"][i],
            C=examples["C"][i],
            D=examples["D"][i]
        )
        label = examples["answer"][i]
        model_inputs["query"].append(query)
        model_inputs["label"].append(label)
    return model_inputs


def format_example_cloze(examples):
    """Batched `datasets.map` callback for cloze-style questions.

    Same input columns as `format_example_choice`, but the label is the TEXT of
    the correct option (looked up via the answer letter), not the letter itself.
    """
    model_inputs = {"query": [], "label": []}
    task_template = "请选择正确的答案填空,仅输出正确的选项。\n{question}\n选项:{A}\n{B}\n{C}\n{D}\n答案:"
    for i in range(len(examples["id"])):
        query = task_template.format(
            question=examples["question"][i],
            A=examples["A"][i],
            B=examples["B"][i],
            C=examples["C"][i],
            D=examples["D"][i]
        )
        # e.g. answer == "B" -> label is the content of column "B" for this row.
        label = examples[examples["answer"][i]][i]
        model_inputs["query"].append(query)
        model_inputs["label"].append(label)
    return model_inputs


# Dispatch table: task_type -> prompt formatter.
TASK_DICT = {
    "choice": format_example_choice,
    "cloze": format_example_cloze
}


def evaluate(
    evalset: str,
    api_base: str,
    split: Optional[str] = "val",
    task_type: Optional[Literal["choice", "cloze"]] = "choice",
    n_samples: Optional[int] = 20
):
    """Evaluate a model served behind an OpenAI-compatible API on a QA dataset.

    Args:
        evalset: Either a local data file (csv/json/jsonl) or a Hub dataset id,
            optionally with a subset suffix after ":" (e.g. "ceval/ceval-exam:law").
        api_base: Base URL of the OpenAI-compatible server (e.g. http://host:8000/v1).
        split: Dataset split to evaluate (ignored for local files, which load as "train").
        task_type: "choice" (compare first letter) or "cloze" (compare option text).
        n_samples: Maximum number of examples to evaluate.

    Side effects: prints accuracy and writes all predictions to "result.json".
    """
    openai.api_base = api_base
    openai.api_key = "none"  # local servers don't check the key, but the client requires one

    if os.path.isfile(evalset):
        dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"]
    else:
        if ":" in evalset:
            evalset, subset = evalset.split(":")
            dataset = load_dataset(evalset, subset, split=split)
        else:
            dataset = load_dataset(evalset, split=split)

    n_samples = min(len(dataset), n_samples)

    # Select BEFORE mapping so we only format the examples we will actually score.
    dataset = dataset.select(range(n_samples))
    dataset = dataset.map(TASK_DICT[task_type], batched=True)

    n_correct = 0
    predictions = []
    for example in tqdm(dataset):
        query = example["query"]
        label = example["label"]
        predict = openai.ChatCompletion.create(
            model="main",
            messages=[{"role": "user", "content": query}],
            temperature=0.01,
            max_tokens=20  # the OpenAI chat API parameter is `max_tokens`, not `max_new_tokens`
        ).choices[0].message.content

        # Guard against an empty completion before indexing predict[0].
        if task_type == "choice" and predict and predict[0].lower() == label[0].lower():
            n_correct += 1
        # Cloze: accept the label appearing at either the start or the end of the output.
        if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]:
            n_correct += 1

        predictions.append({
            "query": query,
            "label": label,
            "predict": predict
        })

    # max(n_samples, 1) avoids ZeroDivisionError on an empty dataset.
    print("Result: {}/{}\nAccuracy: {:.2f}%".format(n_correct, n_samples, n_correct / max(n_samples, 1) * 100))
    with open("result.json", "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    fire.Fire(evaluate)