add cal_ppl script

This commit is contained in:
hiyouga 2024-05-04 22:02:25 +08:00
parent 57a39783d1
commit 3a666832c1
4 changed files with 95 additions and 22 deletions

View File

@ -3,24 +3,22 @@
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
from typing import Optional
import fire
import torch
from deepspeed.accelerator import get_accelerator # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
from llmtuner.chat import ChatModel
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
flash_attn: Optional[bool] = False,
batch_size: int = 1,
seq_length: int = 256,
flash_attn: str = "auto",
):
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="vanilla", flash_attn=flash_attn))
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)

View File

@ -4,7 +4,6 @@
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
from typing import Optional
import fire
import torch
@ -25,12 +24,12 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: Optional[str] = "sft",
dataset: Optional[str] = "alpaca_en",
dataset_dir: Optional[str] = "data",
template: Optional[str] = "default",
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
stage: str = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate,
):
model_args, data_args, training_args, _, _ = get_train_args(
dict(
@ -54,9 +53,7 @@ def calculate_lr(
else:
raise NotImplementedError
dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
)
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()

79
scripts/cal_ppl.py Normal file
View File

@ -0,0 +1,79 @@
# coding=utf-8
# Calculates the ppl of pre-trained models.
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
import json
from typing import Dict
import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model, load_tokenizer
def cal_ppl(
model_name_or_path: str,
batch_size: int = 4,
stage: str = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024,
train_on_prompt: bool = False,
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict(
stage=stage,
model_name_or_path=model_name_or_path,
dataset=dataset,
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
train_on_prompt=train_on_prompt,
output_dir="dummy_dir",
overwrite_cache=True,
)
)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else:
raise NotImplementedError
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none")
perplexities = []
batch: Dict[str, "torch.Tensor"]
with torch.no_grad():
for batch in tqdm(dataloader):
batch = batch.to(model.device)
outputs = model(**batch)
shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
flatten_labels = shift_labels.contiguous().view(-1)
token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
perplexities.extend(sentence_logps.exp().tolist())
with open("ppl.json", "w", encoding="utf-8") as f:
json.dump(perplexities, f, indent=2)
print("Perplexities have been saved at ppl.json.")
if __name__ == "__main__":
fire.Fire(cal_ppl)

View File

@ -3,7 +3,6 @@
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
from collections import defaultdict
from typing import Optional
import fire
from tqdm import tqdm
@ -15,10 +14,10 @@ from llmtuner.model import load_tokenizer
def length_cdf(
model_name_or_path: str,
dataset: Optional[str] = "alpaca_en",
dataset_dir: Optional[str] = "data",
template: Optional[str] = "default",
interval: Optional[int] = 1000,
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
interval: int = 1000,
):
model_args, data_args, training_args, _, _ = get_train_args(
dict(