update evaluator
This commit is contained in:
parent
c907d81667
commit
0012762b04
|
@ -26,9 +26,7 @@ class Evaluator:
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||||
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||||
self.choice_inputs = [
|
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||||
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
|
||||||
]
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||||
|
|
|
@ -10,7 +10,6 @@ class EvalTemplate:
|
||||||
system: str
|
system: str
|
||||||
choice: str
|
choice: str
|
||||||
answer: str
|
answer: str
|
||||||
prefix: str
|
|
||||||
|
|
||||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||||
r"""
|
r"""
|
||||||
|
@ -42,8 +41,8 @@ class EvalTemplate:
|
||||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||||
|
|
||||||
|
|
||||||
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
||||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
|
||||||
|
|
||||||
|
|
||||||
def get_eval_template(name: str) -> "EvalTemplate":
|
def get_eval_template(name: str) -> "EvalTemplate":
|
||||||
|
@ -56,8 +55,7 @@ _register_eval_template(
|
||||||
name="en",
|
name="en",
|
||||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\nAnswer: ",
|
answer="\nAnswer:",
|
||||||
prefix=" ",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,5 +64,4 @@ _register_eval_template(
|
||||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||||
choice="\n{choice}. {content}",
|
choice="\n{choice}. {content}",
|
||||||
answer="\n答案:",
|
answer="\n答案:",
|
||||||
prefix=" ",
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
from llamafactory.eval.template import get_eval_template
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_template_en():
|
||||||
|
support_set = [
|
||||||
|
{
|
||||||
|
"question": "Fewshot question",
|
||||||
|
"A": "Fewshot1",
|
||||||
|
"B": "Fewshot2",
|
||||||
|
"C": "Fewshot3",
|
||||||
|
"D": "Fewshot4",
|
||||||
|
"answer": "B",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
example = {
|
||||||
|
"question": "Target question",
|
||||||
|
"A": "Target1",
|
||||||
|
"B": "Target2",
|
||||||
|
"C": "Target3",
|
||||||
|
"D": "Target4",
|
||||||
|
"answer": "C",
|
||||||
|
}
|
||||||
|
template = get_eval_template(name="en")
|
||||||
|
messages = template.format_example(example, support_set=support_set, subject_name="SubName")
|
||||||
|
assert messages == [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"The following are multiple choice questions (with answers) about SubName.\n\n"
|
||||||
|
"Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "B"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "C"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_eval_template_zh():
|
||||||
|
support_set = [
|
||||||
|
{
|
||||||
|
"question": "示例问题",
|
||||||
|
"A": "示例答案1",
|
||||||
|
"B": "示例答案2",
|
||||||
|
"C": "示例答案3",
|
||||||
|
"D": "示例答案4",
|
||||||
|
"answer": "B",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
example = {
|
||||||
|
"question": "目标问题",
|
||||||
|
"A": "目标答案1",
|
||||||
|
"B": "目标答案2",
|
||||||
|
"C": "目标答案3",
|
||||||
|
"D": "目标答案4",
|
||||||
|
"answer": "C",
|
||||||
|
}
|
||||||
|
template = get_eval_template(name="zh")
|
||||||
|
messages = template.format_example(example, support_set=support_set, subject_name="主题")
|
||||||
|
assert messages == [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
|
||||||
|
"示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "B"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "C"},
|
||||||
|
]
|
Loading…
Reference in New Issue