add avg ppl
parent 76a077bdce
commit 25aeaae51b
@@ -54,6 +54,7 @@ def cal_ppl(
 
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
+    total_ppl = 0
     perplexities = []
     batch: Dict[str, "torch.Tensor"]
     with torch.no_grad():
@@ -68,11 +69,13 @@ def cal_ppl(
             token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
             token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
             sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+            total_ppl += sentence_logps.exp().sum().item()
             perplexities.extend(sentence_logps.exp().tolist())
 
     with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)
 
+    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
     print("Perplexities have been saved at {}.".format(save_name))
 
 
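For reference, below is a minimal standalone sketch of the masked per-sentence perplexity math these hunks build on. The synthetic shapes and tensors are assumptions for illustration, not part of the commit; only IGNORE_INDEX = -100 and the shift/mask/exp logic mirror the script. Note that CrossEntropyLoss defaults to ignore_index=-100, so label-padded positions already contribute zero loss before the mask normalizes the per-sentence mean.

    import torch

    IGNORE_INDEX = -100  # matches CrossEntropyLoss's default ignore_index

    torch.manual_seed(0)
    batch_size, seq_len, vocab_size = 2, 8, 16  # assumed toy shapes
    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[0, 5:] = IGNORE_INDEX  # pretend the first sequence is padded at the tail

    # Shift so that position t predicts token t + 1, as in a causal LM.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    loss_mask = shift_labels != IGNORE_INDEX

    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    token_logps = criterion(
        shift_logits.contiguous().view(-1, vocab_size),
        shift_labels.contiguous().view(-1),
    ).view(batch_size, -1)  # per-token NLL; ignored positions are already 0

    # Mean NLL over real tokens, exponentiated: one perplexity per sentence.
    # (Named sentence_logps to mirror the script, though it holds mean NLLs.)
    sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    perplexities = sentence_logps.exp()
    print("Average perplexity is {:.2f}".format(perplexities.mean().item()))

Accumulating sentence_logps.exp().sum().item() into total_ppl per batch, as the commit does, and dividing by len(perplexities) at the end yields this same arithmetic mean of per-sentence perplexities across the whole dataset.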