forked from p04798526/LLaMA-Factory-Mirror
update ppl script
commit 76a077bdce
parent 3a666832c1
cal_ppl.py
@@ -1,6 +1,6 @@
 # coding=utf-8
-# Calculates the ppl of pre-trained models.
-# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+# Calculates the ppl on the dataset of the pre-trained models.
+# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
 
 import json
 from typing import Dict
@@ -19,6 +19,7 @@ from llmtuner.model import load_model, load_tokenizer
 
 def cal_ppl(
     model_name_or_path: str,
+    save_name: str,
     batch_size: int = 4,
     stage: str = "sft",
     dataset: str = "alpaca_en",
@@ -69,10 +70,10 @@ def cal_ppl(
         sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
         perplexities.extend(sentence_logps.exp().tolist())
 
-    with open("ppl.json", "w", encoding="utf-8") as f:
+    with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)
 
-    print("Perplexities have been saved at ppl.json.")
+    print("Perplexities have been saved at {}.".format(save_name))
 
 
 if __name__ == "__main__":
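
For context, a minimal sketch of how the two unchanged lines in the last hunk yield a perplexity. It assumes token_logps holds per-token cross-entropy losses (positive negative log-likelihoods) and loss_mask marks the scored label positions, so exponentiating their masked mean gives a per-sentence perplexity; the function name sentence_perplexity and the ignore_index default are illustrative, not taken from the script.

import torch


def sentence_perplexity(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Shift so that position t predicts token t+1, as in causal LM loss.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    criterion = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
    # Per-token losses, reshaped back to (batch, seq_len - 1).
    token_logps = criterion(
        shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
    ).view(shift_labels.size())
    loss_mask = (shift_labels != ignore_index).float()
    # The two lines mirrored from the diff: masked mean loss per
    # sentence, then exponentiate.
    sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    return sentence_logps.exp()

Dividing by loss_mask.sum(-1) rather than the sequence length keeps the average over scored tokens only, so padding and prompt positions do not dilute the perplexity.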