From 4f3772b3424de8b980c5b27e027eeff55010b14a Mon Sep 17 00:00:00 2001 From: hiyouga Date: Mon, 26 Jun 2023 23:41:33 +0800 Subject: [PATCH] Update evaluate.py --- evaluation/evaluate.py | 48 +++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index ab793b8f..cb709163 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -1,6 +1,7 @@ # coding=utf-8 # Evaluates fine-tuned models automatically. # Usage: python evaluate.py --evalset ceval/ceval-exam:law --split dev --api_base http://localhost:8000/v1 --task_type choice +# dataset format: question (string), A (string), B (string), C (string), D (string), answer Literal["A", "B", "C", "D"] import os @@ -12,13 +13,6 @@ from typing import Literal, Optional from datasets import load_dataset -EXT2TYPE = { - "csv": "csv", - "json": "json", - "jsonl": "json" -} - - def format_example_choice(examples): model_inputs = {"query": [], "label": []} task_template = "请从ABCD四个选项中选出正确的选项,仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:" @@ -53,9 +47,28 @@ def format_example_cloze(examples): return model_inputs +def format_example_openqa(examples): + model_inputs = {"query": [], "label": []} + task_template = "回答以下问题:{question}\n答案:" + for i in range(len(examples["id"])): + query = task_template.format(question=examples["question"][i]) + label = examples[examples["answer"][i]][i] + model_inputs["query"].append(query) + model_inputs["label"].append(label) + return model_inputs + + TASK_DICT = { "choice": format_example_choice, - "cloze": format_example_cloze + "cloze": format_example_cloze, + "openqa": format_example_openqa +} + + +EXT2TYPE = { + "csv": "csv", + "json": "json", + "jsonl": "json" } @@ -63,7 +76,7 @@ def evaluate( evalset: str, api_base: str, split: Optional[str] = "val", - task_type: Optional[Literal["choice", "cloze"]] = "choice", + task_type: Optional[Literal["choice", "cloze", "openqa"]] = "choice", n_samples: Optional[int] = 20 ): @@ -72,12 +85,11 @@ def evaluate( if os.path.isfile(evalset): dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"] + elif ":" in evalset: + evalset, subset = evalset.split(":") + dataset = load_dataset(evalset, subset, split=split) else: - if ":" in evalset: - evalset, subset = evalset.split(":") - dataset = load_dataset(evalset, subset, split=split) - else: - dataset = load_dataset(evalset, split=split) + dataset = load_dataset(evalset, split=split) n_samples = min(len(dataset), n_samples) @@ -87,12 +99,12 @@ def evaluate( n_correct = 0 predictions = [] for example in tqdm(dataset): - query = example["query"] - label = example["label"] + query, label = example["query"], example["label"] predict = openai.ChatCompletion.create( - model="main", + model="default", messages=[{"role": "user", "content": query}], temperature=0.01, + top_p=0.01, max_new_tokens=20 ).choices[0].message.content @@ -100,6 +112,8 @@ def evaluate( n_correct += 1 if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]: n_correct += 1 + if task_type == "openqa" and label in predict: + n_correct += 1 predictions.append({ "query": query,