This commit is contained in:
hiyouga 2023-09-23 00:42:23 +08:00
parent 465ee8119a
commit 2340b0d7df
1 changed file with 8 additions and 6 deletions


@@ -76,16 +76,18 @@ eval_templates = {
 @torch.inference_mode()
 def batch_inference(
     chat_model: ChatModel,
-    batch_input: Dict[str, torch.Tensor]
+    batch_input: Dict[str, torch.Tensor],
+    lang: Literal["zh", "en"]
 ) -> List[str]:
+    prefix_char = "\n" if lang == "zh" else " "
     logits = chat_model.model(**batch_input).logits
     probs = torch.nn.functional.softmax(
         torch.stack(
             [
-                logits[:, -1, chat_model.tokenizer.encode("\nA")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nB")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nC")[-1]],
-                logits[:, -1, chat_model.tokenizer.encode("\nD")[-1]]
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "A")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "B")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "C")[-1]],
+                logits[:, -1, chat_model.tokenizer.encode(prefix_char + "D")[-1]]
             ],
             dim=-1
         ),
@@ -156,7 +158,7 @@ def evaluate(
             return_attention_mask=True,
             return_tensors="pt"
         ).to(chat_model.model.device)
-        preds = batch_inference(chat_model, batch_input)
+        preds = batch_inference(chat_model, batch_input, lang)
         outputs += preds
     corrects = (np.array(outputs) == np.array(labels))
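In prose, the commit threads the evaluation language down to the scoring step: Chinese prompts end the answer cue with "\n" and English prompts with " ", so the A/B/C/D logit lookup now encodes each letter behind the matching prefix character. Below is a minimal sketch of that lookup using a standalone Hugging Face tokenizer; the "gpt2" checkpoint and the choice_token_ids helper are illustrative placeholders and are not part of this commit.

    # Sketch (assumption): deriving the prefix-aware token ids for the choices
    # with a Hugging Face tokenizer. "gpt2" is a placeholder checkpoint and
    # choice_token_ids is a hypothetical helper; neither appears in the commit.
    from typing import List
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def choice_token_ids(lang: str) -> List[int]:
        # Mirrors the diff: Chinese templates put "\n" before the answer letter,
        # English templates put a space.
        prefix_char = "\n" if lang == "zh" else " "
        # encode(prefix + letter)[-1] keeps the last sub-token, i.e. the id the
        # model should emit right after a prompt that already ends in prefix_char.
        return [tokenizer.encode(prefix_char + letter)[-1] for letter in "ABCD"]

    print(choice_token_ids("en"))  # ids used to score " A", " B", " C", " D"
    print(choice_token_ids("zh"))  # ids used when the letters follow a newline

Taking the last encoded id matters because the prefix may either merge with the letter into a single token or split into two, depending on the tokenizer; in both cases the final id is the one whose logit should be compared across the four options.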