update evaluator
This commit is contained in:
parent
c907d81667
commit
0012762b04
|
@ -26,9 +26,7 @@ class Evaluator:
|
|||
self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
|
||||
self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
|
||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||
self.choice_inputs = [
|
||||
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES
|
||||
]
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||
|
|
|
@ -10,7 +10,6 @@ class EvalTemplate:
|
|||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
r"""
|
||||
|
@ -42,8 +41,8 @@ class EvalTemplate:
|
|||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str, prefix: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer, prefix=prefix)
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
||||
eval_templates[name] = EvalTemplate(system=system, choice=choice, answer=answer)
|
||||
|
||||
|
||||
def get_eval_template(name: str) -> "EvalTemplate":
|
||||
|
@ -56,8 +55,7 @@ _register_eval_template(
|
|||
name="en",
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" ",
|
||||
answer="\nAnswer:",
|
||||
)
|
||||
|
||||
|
||||
|
@ -66,5 +64,4 @@ _register_eval_template(
|
|||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix=" ",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
from llamafactory.eval.template import get_eval_template
|
||||
|
||||
|
||||
def test_eval_template_en():
|
||||
support_set = [
|
||||
{
|
||||
"question": "Fewshot question",
|
||||
"A": "Fewshot1",
|
||||
"B": "Fewshot2",
|
||||
"C": "Fewshot3",
|
||||
"D": "Fewshot4",
|
||||
"answer": "B",
|
||||
}
|
||||
]
|
||||
example = {
|
||||
"question": "Target question",
|
||||
"A": "Target1",
|
||||
"B": "Target2",
|
||||
"C": "Target3",
|
||||
"D": "Target4",
|
||||
"answer": "C",
|
||||
}
|
||||
template = get_eval_template(name="en")
|
||||
messages = template.format_example(example, support_set=support_set, subject_name="SubName")
|
||||
assert messages == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"The following are multiple choice questions (with answers) about SubName.\n\n"
|
||||
"Fewshot question\nA. Fewshot1\nB. Fewshot2\nC. Fewshot3\nD. Fewshot4\nAnswer:"
|
||||
),
|
||||
},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Target question\nA. Target1\nB. Target2\nC. Target3\nD. Target4\nAnswer:",
|
||||
},
|
||||
{"role": "assistant", "content": "C"},
|
||||
]
|
||||
|
||||
|
||||
def test_eval_template_zh():
|
||||
support_set = [
|
||||
{
|
||||
"question": "示例问题",
|
||||
"A": "示例答案1",
|
||||
"B": "示例答案2",
|
||||
"C": "示例答案3",
|
||||
"D": "示例答案4",
|
||||
"answer": "B",
|
||||
}
|
||||
]
|
||||
example = {
|
||||
"question": "目标问题",
|
||||
"A": "目标答案1",
|
||||
"B": "目标答案2",
|
||||
"C": "目标答案3",
|
||||
"D": "目标答案4",
|
||||
"answer": "C",
|
||||
}
|
||||
template = get_eval_template(name="zh")
|
||||
messages = template.format_example(example, support_set=support_set, subject_name="主题")
|
||||
assert messages == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"以下是中国关于主题考试的单项选择题,请选出其中的正确答案。\n\n"
|
||||
"示例问题\nA. 示例答案1\nB. 示例答案2\nC. 示例答案3\nD. 示例答案4\n答案:"
|
||||
),
|
||||
},
|
||||
{"role": "assistant", "content": "B"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "目标问题\nA. 目标答案1\nB. 目标答案2\nC. 目标答案3\nD. 目标答案4\n答案:",
|
||||
},
|
||||
{"role": "assistant", "content": "C"},
|
||||
]
|
Loading…
Reference in New Issue