fix MMLU
This commit is contained in:
parent
465ee8119a
commit
2340b0d7df
|
@ -76,16 +76,18 @@ eval_templates = {
|
|||
@torch.inference_mode()
|
||||
def batch_inference(
|
||||
chat_model: ChatModel,
|
||||
batch_input: Dict[str, torch.Tensor]
|
||||
batch_input: Dict[str, torch.Tensor],
|
||||
lang: Literal["zh", "en"]
|
||||
) -> List[str]:
|
||||
prefix_char = "\n" if lang == "zh" else " "
|
||||
logits = chat_model.model(**batch_input).logits
|
||||
probs = torch.nn.functional.softmax(
|
||||
torch.stack(
|
||||
[
|
||||
logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
|
||||
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
|
||||
logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
|
||||
],
|
||||
dim=-1
|
||||
),
|
||||
|
@ -156,7 +158,7 @@ def evaluate(
|
|||
return_attention_mask=True,
|
||||
return_tensors="pt"
|
||||
).to(chat_model.model.device)
|
||||
preds = batch_inference(chat_model, batch_input)
|
||||
preds = batch_inference(chat_model, batch_input, lang)
|
||||
outputs += preds
|
||||
|
||||
corrects = (np.array(outputs) == np.array(labels))
|
||||
|
|
Loading…
Reference in New Issue