diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index df22ca70..4e74f5f9 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -76,16 +76,18 @@ eval_templates = {
 @torch.inference_mode()
 def batch_inference(
     chat_model: ChatModel,
-    batch_input: Dict[str, torch.Tensor]
+    batch_input: Dict[str, torch.Tensor],
+    lang: Literal["zh", "en"]
 ) -> List[str]:
+    prefix_char = "\n" if lang == "zh" else " "
     logits = chat_model.model(**batch_input).logits
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
             ],
             dim=-1
         ),
@@ -156,7 +158,7 @@ def evaluate(
             return_attention_mask=True,
             return_tensors="pt"
         ).to(chat_model.model.device)
-        preds = batch_inference(chat_model, batch_input)
+        preds = batch_inference(chat_model, batch_input, lang)
         outputs += preds
 
     corrects = (np.array(outputs) == np.array(labels))
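Context on the change: BPE/SentencePiece tokenizers are context-sensitive, so the token id returned by `encode("\nA")[-1]` can differ from the one returned by `encode(" A")[-1]`. Judging from the diff, the Chinese templates place the answer letter after a newline while the English ones place it after a space, so the hard-coded `"\n"` prefix could score the wrong token ids on English prompts; threading `lang` through `batch_inference` selects the matching prefix per language.

Below is a minimal standalone sketch of the same last-token scoring trick, assuming a plain Hugging Face causal LM; the `gpt2` checkpoint and the `choice_probs` helper are illustrative, not part of this repo:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint only; the repo wraps its model in ChatModel instead.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

@torch.inference_mode()
def choice_probs(prompt: str, prefix_char: str = " ") -> torch.Tensor:
    # prefix_char plays the same role as in the diff: " " for English
    # prompts, "\n" for Chinese ones, so encode(prefix_char + letter)[-1]
    # is the token id the model would actually emit after the prompt.
    batch_input = tokenizer(prompt, return_tensors="pt")
    logits = model(**batch_input).logits  # shape: (1, seq_len, vocab_size)
    # encode() may split prefix_char + letter into several tokens;
    # the last one is the letter itself, tokenized in context.
    choice_ids = [tokenizer.encode(prefix_char + c)[-1] for c in "ABCD"]
    # Compare only the four answer-letter logits at the final position.
    return torch.softmax(logits[0, -1, choice_ids], dim=-1)

probs = choice_probs("Question: 2 + 2 = ?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:")
print("ABCD"[int(probs.argmax())])  # predicted choice letter
```

Restricting the softmax to the four answer-letter logits at the final position mirrors the patched `batch_inference`: the model is never asked to generate text, only to rank A/B/C/D, which keeps the evaluation deterministic and cheap.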