add test scripts

This commit is contained in:
hiyouga 2024-02-19 02:09:13 +08:00
parent d46977edf5
commit 26912cd816
2 changed files with 69 additions and 8 deletions

View File

@ -10,7 +10,7 @@ import fire
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from llmtuner.data import get_dataset from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
@ -24,26 +24,35 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr( def calculate_lr(
model_name_or_path: str, model_name_or_path: str,
dataset: str,
cutoff_len: int, # i.e. maximum input length during training
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
is_mistral: bool, # mistral model uses a smaller learning rate, stage: Optional[str] = "sft",
dataset: Optional[str] = "alpaca_en",
dataset_dir: Optional[str] = "data", dataset_dir: Optional[str] = "data",
template: Optional[str] = "default",
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
): ):
model_args, data_args, training_args, finetuning_args, _ = get_train_args( model_args, data_args, training_args, finetuning_args, _ = get_train_args(
dict( dict(
stage="sft", stage=stage,
model_name_or_path=model_name_or_path, model_name_or_path=model_name_or_path,
dataset=dataset, dataset=dataset,
dataset_dir=dataset_dir, dataset_dir=dataset_dir,
template="default", template=template,
cutoff_len=cutoff_len, cutoff_len=cutoff_len,
output_dir="dummy_dir", output_dir="dummy_dir",
overwrite_cache=True,
) )
) )
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False) _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft") trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
if stage == "pt":
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else:
raise NotImplementedError
dataloader = DataLoader( dataloader = DataLoader(
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
) )

52
tests/length_cdf.py Normal file
View File

@ -0,0 +1,52 @@
# coding=utf-8
# Calculates the distribution of the input lengths in the dataset.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
from collections import defaultdict
from typing import Optional
import fire
from tqdm import tqdm
from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer
def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
    r"""Print the cumulative length distribution of tokenized samples in a dataset.

    Samples are bucketed by token length into bins of ``interval`` tokens; for each
    bucket (ascending) the running sample count and cumulative percentage below the
    bucket's upper bound are printed.
    """
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,  # effectively disables truncation so real lengths are observed
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
    sample_count = len(trainset)
    bucket_counts = defaultdict(int)
    for input_ids in tqdm(trainset["input_ids"]):
        # Floor each length to its bucket's lower bound.
        bucket_counts[len(input_ids) // interval * interval] += 1
    seen, percentile = 0, 0
    for bucket_start, bucket_size in sorted(bucket_counts.items()):
        seen += bucket_size
        percentile += bucket_size / sample_count * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(seen, percentile, bucket_start + interval))
# Expose the function as a command-line tool via python-fire when run as a script.
if __name__ == "__main__":
    fire.Fire(length_cdf)