add cal_lr.py

This commit is contained in:
hiyouga 2023-11-14 20:58:37 +08:00
parent d125ef5535
commit 42c8fc4fb9
4 changed files with 67 additions and 6 deletions

View File

@ -84,10 +84,9 @@ def load_model_and_tokenizer(
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is not None: # for training
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # for evaluation, priority: bf16 > fp16 > fp32
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":

View File

@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
from llmtuner import ChatModel
def calculate(
def calculate_flops(
model_name_or_path: str,
batch_size: Optional[int] = 1,
seq_length: Optional[int] = 256,
@ -41,4 +41,4 @@ def calculate(
if __name__ == "__main__":
fire.Fire(calculate)
fire.Fire(calculate_flops)

63
tests/cal_lr.py Normal file
View File

@ -0,0 +1,63 @@
# coding=utf-8
# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import fire
import math
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
BASE_LR = 3e-4
BASE_BS = 4_000_000
def calculate_lr(
model_name_or_path: str,
dataset: str,
cutoff_len: int,
batch_size: int
):
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
stage="sft",
model_name_or_path=model_name_or_path,
dataset=dataset,
template="default",
cutoff_len=cutoff_len,
output_dir="dummy_dir",
fp16=True
))
trainset = get_dataset(model_args, data_args)
_, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
dataloader = DataLoader(
dataset=trainset,
batch_size=batch_size,
shuffle=True,
collate_fn=data_collator,
pin_memory=True
)
valid_tokens, total_tokens = 0, 0
for batch in tqdm(dataloader):
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
total_tokens += torch.numel(batch["labels"])
batch_max_len = cutoff_len * batch_size # max tokens in a batch
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len
))
if __name__ == "__main__":
fire.Fire(calculate_lr)

View File

@ -4,7 +4,6 @@
# --max_length 1024 --max_samples 1024
# dataset format: instruction (string), input (string), output (string), history (List[string])
import fire
from datasets import load_dataset
from transformers import AutoTokenizer