add averaging in evaluation

hiyouga 2023-10-10 23:16:31 +08:00
parent be420e4179
commit 5310e4d182
1 changed file with 44 additions and 47 deletions


@@ -9,10 +9,11 @@ import fire
 import json
 import torch
 import numpy as np
-from tqdm import tqdm, trange
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
+from collections import Counter
 from datasets import load_dataset
 from dataclasses import dataclass
+from tqdm import tqdm, trange
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
 
 from llmtuner import ChatModel
@@ -86,10 +87,8 @@ def batch_inference(
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
+                for choice in choices
             ],
             dim=-1
         ),
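Aside (not part of the diff): a minimal sketch of what the refactored expression in the hunk above computes, assuming `choices` is the script's list of option letters (["A", "B", "C", "D"]) and `tokenizer` is a Hugging Face tokenizer; the function name and the final argmax-to-letter step are illustrative, not the script's own code.

    import torch

    choices = ["A", "B", "C", "D"]

    def pick_choice(logits: torch.Tensor, tokenizer, prefix_char: str = " ") -> list:
        # Vocabulary id of the last sub-token of " A", " B", " C", " D" (no special tokens added).
        choice_ids = [tokenizer.encode(prefix_char + c, add_special_tokens=False)[-1] for c in choices]
        # Restricting the last position's logits to these ids is equivalent to stacking the
        # per-choice logits as the diff does; softmax then argmax picks the answer letter.
        probs = torch.nn.functional.softmax(logits[:, -1, choice_ids], dim=-1)
        return [chr(ord("A") + i.item()) for i in torch.argmax(probs, dim=-1)]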
@@ -108,11 +107,12 @@ def evaluate(
     split: Optional[Literal["validation", "test"]] = "validation",
     lang: Optional[Literal["zh", "en"]] = "zh",
     n_shot: Optional[int] = 5,
+    n_avg: Optional[int] = 1,
     batch_size: Optional[int] = 4,
     save_name: Optional[str] = None
 ):
     with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
-        categorys = json.load(f)
+        categorys: Dict[str, Dict[str, str]] = json.load(f)
 
     chat_model = ChatModel(dict(
         model_name_or_path=model_name_or_path,
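Aside (not part of the diff): a standalone illustration of the majority-vote averaging that the new `n_avg` option enables (the voting itself appears in the last hunk below); function and variable names here are illustrative, not the script's own.

    from collections import Counter

    import numpy as np

    def majority_vote(all_outputs, labels):
        # all_outputs holds one list of predicted letters per trial; labels are the gold letters.
        n_avg = len(all_outputs)
        answers = []
        for i in range(len(all_outputs[0])):
            count = Counter(all_outputs[epoch][i] for epoch in range(n_avg))
            answers.append(count.most_common(1)[0][0])  # most frequent prediction wins
        accuracy = float(100 * np.mean(np.array(answers) == np.array(labels)))
        return answers, accuracy

    # Three trials over four questions; question 2 is predicted "B" in two of three trials, so "B" wins.
    trials = [["A", "B", "C", "D"], ["A", "C", "C", "D"], ["A", "B", "D", "D"]]
    print(majority_vote(trials, ["A", "B", "C", "D"]))  # (['A', 'B', 'C', 'D'], 100.0)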
@@ -124,17 +124,17 @@ def evaluate(
     assert chat_model.tokenizer.padding_side == "left", "only left-padded tensor can be accepted."
     category_corrects: Dict[str, np.ndarray] = {
-        subj: np.array([], dtype="bool") for subj in ["STEM", "Social Sciences", "Humanities", "Other"]
+        subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
     }
-    overall_corrects = np.array([], dtype="bool")
     pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
     results = {}
     for subject in pbar:
-        pbar.set_postfix_str(categorys[subject]["name"])
-        inputs, labels = [], []
         dataset = load_dataset(os.path.join(dataset_dir, task), subject)
-        for i in range(len(dataset[split])):
+        labels, answers, all_outputs = [], [], []
+        for epoch in range(n_avg):
+            pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
+            inputs, outputs = [], []
+            for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
                 support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
                 query, resp, history = eval_template.format_example(
                     target_data=dataset[split][i],
@@ -143,37 +143,34 @@ def evaluate(
                     use_history=chat_model.template.use_history
                 )
                 input_ids, _ = chat_model.template.encode_oneturn(
-                    tokenizer=chat_model.tokenizer,
-                    query=query,
-                    resp=resp,
-                    history=history
+                    tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
                 )
-                inputs.append({
-                    "input_ids": input_ids,
-                    "attention_mask": [1] * len(input_ids)
-                })
+                inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
+                if epoch == 0:
                     labels.append(resp)
 
-        outputs = []
-        for i in trange(0, len(inputs), batch_size, desc="Processing batches", position=1, leave=False):
+            for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
                 batch_input = chat_model.tokenizer.pad(
-                    inputs[i : i + batch_size],
-                    return_attention_mask=True,
-                    return_tensors="pt"
+                    inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
                 ).to(chat_model.model.device)
                 preds = batch_inference(chat_model, batch_input, eval_template.prefix)
                 outputs += preds
+            all_outputs.append(outputs)
 
-        corrects = (np.array(outputs) == np.array(labels))
+        for i in range(len(all_outputs[0])):
+            count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
+            answers.append(count.most_common(1)[0][0])
+
+        corrects = (np.array(answers) == np.array(labels))
         category_name = categorys[subject]["category"]
         category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
-        overall_corrects = np.concatenate([overall_corrects, corrects], axis=0)
-        results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
+        category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
+        results[subject] = {str(i): answers[i] for i in range(len(answers))}
 
-    score_info = "Average accuracy: {:.2f}".format(100 * np.mean(overall_corrects))
-    for category_name, category_correct in category_corrects.items():
-        if len(category_correct):
-            score_info += "\n{:>16}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+    score_info = "\n".join([
+        "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+        for category_name, category_correct in category_corrects.items() if len(category_correct)
+    ])
 
     print(score_info)
     if save_name is not None: