update evaluator

This commit is contained in:
hiyouga 2024-06-10 23:56:00 +08:00
parent c907d81667
commit 0012762b04
3 changed files with 81 additions and 9 deletions

View File

@ -26,9 +26,7 @@ class Evaluator:
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template) self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.model = load_model(self.tokenizer, self.model_args, finetuning_args) self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
self.eval_template = get_eval_template(self.eval_args.lang) self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [ self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:

View File

@ -10,7 +10,6 @@ class EvalTemplate:
system: str system: str
choice: str choice: str
answer: str answer: str
prefix: str
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]: def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
r""" r"""
@ -42,8 +41,8 @@ class EvalTemplate:
eval_templates: Dict[str, "EvalTemplate"] = {} eval_templates: Dict[str, "EvalTemplate"] = {}
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None: def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix) eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
def get_eval_template(name: str) -> "EvalTemplate": def get_eval_template(name: str) -> "EvalTemplate":
@ -56,8 +55,7 @@ _register_eval_template(
name="en", name="en",
system="The following are multiple choice questions (with answers) about {subject}.\n\n", system="The following are multiple choice questions (with answers) about {subject}.\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\nAnswer: ", answer="\nAnswer:",
prefix=" ",
) )
@ -66,5 +64,4 @@ _register_eval_template(
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n", system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
choice="\n{choice}. {content}", choice="\n{choice}. {content}",
answer="\n答案:", answer="\n答案:",
prefix=" ",
) )

View File

@ -0,0 +1,77 @@
from llamafactory.eval.template import get_eval_template
def test_eval_template_en():
support_set = [
{
"question": "Fewshot question",
"A": "Fewshot1",
"B": "Fewshot2",
"C": "Fewshot3",
"D": "Fewshot4",
"answer": "B",
}
]
example = {
"question": "Target question",
"A": "Target1",
"B": "Target2",
"C": "Target3",
"D": "Target4",
"answer": "C",
}
template = get_eval_template(name="en")
messages = template.format_example(example, support_set=support_set, subject_name="SubName")
assert messages == [
{
"role": "user",
"content": (
"The following are multiple choice questions (with answers) about SubName.\n\n"
"Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
),
},
{"role": "assistant", "content": "B"},
{
"role": "user",
"content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
},
{"role": "assistant", "content": "C"},
]
def test_eval_template_zh():
support_set = [
{
"question": "示例问题",
"A": "示例答案1",
"B": "示例答案2",
"C": "示例答案3",
"D": "示例答案4",
"answer": "B",
}
]
example = {
"question": "目标问题",
"A": "目标答案1",
"B": "目标答案2",
"C": "目标答案3",
"D": "目标答案4",
"answer": "C",
}
template = get_eval_template(name="zh")
messages = template.format_example(example, support_set=support_set, subject_name="主题")
assert messages == [
{
"role": "user",
"content": (
"以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
"示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
),
},
{"role": "assistant", "content": "B"},
{
"role": "user",
"content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
},
{"role": "assistant", "content": "C"},
]