update ppl script

This commit is contained in:
hiyouga 2024-05-04 22:13:14 +08:00
parent 3a666832c1
commit 76a077bdce
1 changed files with 5 additions and 4 deletions

View File

@ -1,6 +1,6 @@
# coding=utf-8
# Calculates the ppl of pre-trained models.
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Calculates the ppl on the dataset of the pre-trained models.
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
import json
from typing import Dict
@ -19,6 +19,7 @@ from llmtuner.model import load_model, load_tokenizer
def cal_ppl(
model_name_or_path: str,
save_name: str,
batch_size: int = 4,
stage: str = "sft",
dataset: str = "alpaca_en",
@ -69,10 +70,10 @@ def cal_ppl(
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
perplexities.extend(sentence_logps.exp().tolist())
with open("ppl.json", "w", encoding="utf-8") as f:
with open(save_name, "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2)
print("Perplexities have been saved at ppl.json.")
print("Perplexities have been saved at {}.".format(save_name))
if __name__ == "__main__":