add test scripts
This commit is contained in:
parent
d46977edf5
commit
26912cd816
|
@ -10,7 +10,7 @@ import fire
|
|||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
@ -24,26 +24,35 @@ BASE_BS = 4_000_000 # from llama paper
|
|||
|
||||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
dataset: str,
|
||||
cutoff_len: int, # i.e. maximum input length during training
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
is_mistral: bool, # mistral model uses a smaller learning rate,
|
||||
stage: Optional[str] = "sft",
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
cutoff_len: Optional[int] = 1024, # i.e. maximum input length during training
|
||||
is_mistral: Optional[bool] = False, # mistral model uses a smaller learning rate,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage="sft",
|
||||
stage=stage,
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template="default",
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
|
||||
)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# coding=utf-8
|
||||
# Calculates the distribution of the input lengths in the dataset.
|
||||
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from llmtuner.data import get_dataset
|
||||
from llmtuner.hparams import get_train_args
|
||||
from llmtuner.model import load_model_and_tokenizer
|
||||
|
||||
|
||||
def length_cdf(
|
||||
model_name_or_path: str,
|
||||
dataset: Optional[str] = "alpaca_en",
|
||||
dataset_dir: Optional[str] = "data",
|
||||
template: Optional[str] = "default",
|
||||
interval: Optional[int] = 1000,
|
||||
):
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(
|
||||
dict(
|
||||
stage="sft",
|
||||
model_name_or_path=model_name_or_path,
|
||||
dataset=dataset,
|
||||
dataset_dir=dataset_dir,
|
||||
template=template,
|
||||
cutoff_len=1_000_000,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
)
|
||||
)
|
||||
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
|
||||
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
|
||||
total_num = len(trainset)
|
||||
length_dict = defaultdict(int)
|
||||
for sample in tqdm(trainset["input_ids"]):
|
||||
length_dict[len(sample) // interval * interval] += 1
|
||||
|
||||
length_tuples = list(length_dict.items())
|
||||
length_tuples.sort()
|
||||
count_accu, prob_accu = 0, 0
|
||||
for length, count in length_tuples:
|
||||
count_accu += count
|
||||
prob_accu += count / total_num * 100
|
||||
print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(length_cdf)
|
Loading…
Reference in New Issue