diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py index 6c8c6174..06c2a43b 100644 --- a/scripts/cal_ppl.py +++ b/scripts/cal_ppl.py @@ -54,6 +54,7 @@ def cal_ppl( dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True) criterion = torch.nn.CrossEntropyLoss(reduction="none") + total_ppl = 0 perplexities = [] batch: Dict[str, "torch.Tensor"] with torch.no_grad(): @@ -68,11 +69,13 @@ def cal_ppl( token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels) token_logps = token_logps.contiguous().view(shift_logits.size(0), -1) sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + total_ppl += sentence_logps.exp().sum().item() perplexities.extend(sentence_logps.exp().tolist()) with open(save_name, "w", encoding="utf-8") as f: json.dump(perplexities, f, indent=2) + print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities))) print("Perplexities have been saved at {}.".format(save_name))