add avg ppl

This commit is contained in:
hiyouga 2024-05-04 22:35:31 +08:00
parent 76a077bdce
commit 25aeaae51b
1 changed files with 3 additions and 0 deletions

View File

@ -54,6 +54,7 @@ def cal_ppl(
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none") criterion = torch.nn.CrossEntropyLoss(reduction="none")
total_ppl = 0
perplexities = [] perplexities = []
batch: Dict[str, "torch.Tensor"] batch: Dict[str, "torch.Tensor"]
with torch.no_grad(): with torch.no_grad():
@ -68,11 +69,13 @@ def cal_ppl(
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels) token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1) token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
total_ppl += sentence_logps.exp().sum().item()
perplexities.extend(sentence_logps.exp().tolist()) perplexities.extend(sentence_logps.exp().tolist())
with open(save_name, "w", encoding="utf-8") as f: with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2) json.dump(perplexities, f, indent=2)
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
print("Perplexities have been saved at {}.".format(save_name)) print("Perplexities have been saved at {}.".format(save_name))