diff --git a/tests/cal_lr.py b/tests/cal_lr.py
index 8c5cd909..6decf0c2 100644
--- a/tests/cal_lr.py
+++ b/tests/cal_lr.py
@@ -10,7 +10,7 @@ import fire
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

 from llmtuner.data import get_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
@@ -24,26 +24,35 @@ BASE_BS = 4_000_000  # from llama paper

 def calculate_lr(
     model_name_or_path: str,
-    dataset: str,
-    cutoff_len: int,  # i.e. maximum input length during training
     batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
-    is_mistral: bool,  # mistral model uses a smaller learning rate,
+    stage: Optional[str] = "sft",
+    dataset: Optional[str] = "alpaca_en",
     dataset_dir: Optional[str] = "data",
+    template: Optional[str] = "default",
+    cutoff_len: Optional[int] = 1024,  # i.e. maximum input length during training
+    is_mistral: Optional[bool] = False,  # mistral model uses a smaller learning rate,
 ):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(
         dict(
-            stage="sft",
+            stage=stage,
             model_name_or_path=model_name_or_path,
             dataset=dataset,
             dataset_dir=dataset_dir,
-            template="default",
+            template=template,
             cutoff_len=cutoff_len,
             output_dir="dummy_dir",
+            overwrite_cache=True,
         )
     )
     _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
+    if stage == "pt":
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+    elif stage == "sft":
+        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    else:
+        raise NotImplementedError
+
     dataloader = DataLoader(
         dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
     )
diff --git a/tests/length_cdf.py b/tests/length_cdf.py
new file mode 100644
index 00000000..d9cb06f5
--- /dev/null
+++ b/tests/length_cdf.py
@@ -0,0 +1,52 @@
+# coding=utf-8
+# Calculates the distribution of the input lengths in the dataset.
+# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+
+from collections import defaultdict
+from typing import Optional
+
+import fire
+from tqdm import tqdm
+
+from llmtuner.data import get_dataset
+from llmtuner.hparams import get_train_args
+from llmtuner.model import load_model_and_tokenizer
+
+
+def length_cdf(
+    model_name_or_path: str,
+    dataset: Optional[str] = "alpaca_en",
+    dataset_dir: Optional[str] = "data",
+    template: Optional[str] = "default",
+    interval: Optional[int] = 1000,
+):
+    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
+        dict(
+            stage="sft",
+            model_name_or_path=model_name_or_path,
+            dataset=dataset,
+            dataset_dir=dataset_dir,
+            template=template,
+            cutoff_len=1_000_000,
+            output_dir="dummy_dir",
+            overwrite_cache=True,
+        )
+    )
+    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
+    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
+    total_num = len(trainset)
+    length_dict = defaultdict(int)
+    for sample in tqdm(trainset["input_ids"]):
+        length_dict[len(sample) // interval * interval] += 1
+
+    length_tuples = list(length_dict.items())
+    length_tuples.sort()
+    count_accu, prob_accu = 0, 0
+    for length, count in length_tuples:
+        count_accu += count
+        prob_accu += count / total_num * 100
+        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+
+
+if __name__ == "__main__":
+    fire.Fire(length_cdf)
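
A minimal usage sketch, not part of the patch above: it assumes both scripts are imported from the tests/ directory, that "path_to_model" points to a local checkpoint, and that the batch size shown is a placeholder rather than a recommendation.

# usage_sketch.py -- hypothetical helper, not included in this patch.
# Assumes it sits next to cal_lr.py and length_cdf.py in tests/ and that
# "path_to_model" is a local checkpoint; batch_size=256 is a placeholder.

from cal_lr import calculate_lr
from length_cdf import length_cdf

# Print the cumulative distribution of tokenized lengths of alpaca_en in
# 1000-token buckets, which helps pick a sensible cutoff_len.
length_cdf(model_name_or_path="path_to_model", dataset="alpaca_en", template="default", interval=1000)

# Estimate a learning rate for SFT; stage="pt" would instead select the
# DataCollatorForLanguageModeling branch added by this patch.
calculate_lr(model_name_or_path="path_to_model", batch_size=256, stage="sft", dataset="alpaca_en", cutoff_len=1024)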