Update ppl script: fix usage comment (cal_flops.py → cal_ppl.py) and save results to the user-specified `save_name` path instead of hard-coded "ppl.json"
This commit is contained in:
parent
3a666832c1
commit
76a077bdce
|
@ -1,6 +1,6 @@
|
|||
# coding=utf-8
|
||||
# Calculates the ppl of pre-trained models.
|
||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Calculates the ppl on the dataset of the pre-trained models.
|
||||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
|
@ -19,6 +19,7 @@ from llmtuner.model import load_model, load_tokenizer
|
|||
|
||||
def cal_ppl(
|
||||
model_name_or_path: str,
|
||||
save_name: str,
|
||||
batch_size: int = 4,
|
||||
stage: str = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
|
@ -69,10 +70,10 @@ def cal_ppl(
|
|||
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
perplexities.extend(sentence_logps.exp().tolist())
|
||||
|
||||
with open("ppl.json", "w", encoding="utf-8") as f:
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Perplexities have been saved at ppl.json.")
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue