From 42c8fc4fb970775159a68a123d5c7bedb701c8cf Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 Nov 2023 20:58:37 +0800 Subject: [PATCH] add cal_lr.py --- src/llmtuner/tuner/core/loader.py | 5 +-- tests/cal_flops.py | 4 +- tests/cal_lr.py | 63 +++++++++++++++++++++++++++++++ tests/quantize.py | 1 - 4 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 tests/cal_lr.py diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py index 3a51a698..38d5f71e 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/tuner/core/loader.py @@ -84,10 +84,9 @@ def load_model_and_tokenizer( tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) # Set model dtype - if model_args.compute_dtype is not None: # for training - setattr(config, "torch_dtype", model_args.compute_dtype) - else: # for evaluation, priority: bf16 > fp16 > fp32 + if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + setattr(config, "torch_dtype", model_args.compute_dtype) # Fix config (for Qwen) if getattr(config, "model_type", None) == "qwen": diff --git a/tests/cal_flops.py b/tests/cal_flops.py index 01b005af..ff0db0a2 100644 --- a/tests/cal_flops.py +++ b/tests/cal_flops.py @@ -12,7 +12,7 @@ from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore from llmtuner import ChatModel -def calculate( +def calculate_flops( model_name_or_path: str, batch_size: Optional[int] = 1, seq_length: Optional[int] = 256, @@ -41,4 +41,4 @@ def calculate( if __name__ == "__main__": - fire.Fire(calculate) + fire.Fire(calculate_flops) diff --git a/tests/cal_lr.py b/tests/cal_lr.py new file mode 100644 index 00000000..8a932d5e --- /dev/null +++ b/tests/cal_lr.py @@ -0,0 +1,63 @@ +# coding=utf-8 +# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters. 
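+# It applies the square-root scaling rule: lr = BASE_LR * sqrt(batch_valid_len / BASE_BS),
+# where BASE_LR and BASE_BS are LLaMA's pre-training learning rate (3e-4) and token
+# batch size (4M tokens), and batch_valid_len is the average number of loss-contributing
+# (non-padding, non-prompt) tokens per batch, measured on the target dataset below.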
+# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
+# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
+
+import fire
+import math
+import torch
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+from transformers import DataCollatorForSeq2Seq
+
+from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+
+
+BASE_LR = 3e-4  # learning rate used to pre-train the 7B/13B LLaMA models
+BASE_BS = 4_000_000  # tokens per batch in LLaMA pre-training
+
+
+def calculate_lr(
+    model_name_or_path: str,
+    dataset: str,
+    cutoff_len: int,  # maximum input length during training
+    batch_size: int  # total batch size across devices and accumulation steps
+):
+    model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
+        stage="sft",
+        model_name_or_path=model_name_or_path,
+        dataset=dataset,
+        template="default",
+        cutoff_len=cutoff_len,
+        output_dir="dummy_dir",
+        fp16=True
+    ))
+    trainset = get_dataset(model_args, data_args)
+    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+    trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft")
+    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    dataloader = DataLoader(
+        dataset=trainset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        pin_memory=True
+    )
+    valid_tokens, total_tokens = 0, 0
+    for batch in tqdm(dataloader):
+        valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()  # tokens that contribute to the loss
+        total_tokens += torch.numel(batch["labels"])
+
+    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
+    valid_ratio = valid_tokens / total_tokens
+    batch_valid_len = batch_max_len * valid_ratio
+    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # square-root LR scaling rule
+    print("Optimal learning rate is {:.2e} for valid ratio {:.2f}% and effective token batch size {:.2f}".format(
+        lr, valid_ratio * 100, batch_valid_len
+    ))
+
+
+if __name__ == "__main__":
+    fire.Fire(calculate_lr)
diff --git a/tests/quantize.py b/tests/quantize.py
index 25321cf3..7b529671 100644
--- a/tests/quantize.py
+++ b/tests/quantize.py
@@ -4,7 +4,6 @@
 # --max_length 1024 --max_samples 1024
 # dataset format: instruction (string), input (string), output (string), history (List[string])
-
 import fire
 from datasets import load_dataset
 from transformers import AutoTokenizer
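As a quick sanity check of the scaling rule implemented above (an illustrative sketch, not part of the patch: the 60% valid-token ratio is an assumed value, and scaled_lr is a hypothetical helper):

    import math

    BASE_LR = 3e-4       # LLaMA pre-training learning rate
    BASE_BS = 4_000_000  # LLaMA pre-training batch size, in tokens

    def scaled_lr(batch_valid_len: float) -> float:
        # Square-root scaling: shrinking the token batch by 4x halves the learning rate
        return BASE_LR * math.sqrt(batch_valid_len / BASE_BS)

    # batch_size=16, cutoff_len=1024, assumed 60% of tokens contribute to the loss
    print("{:.2e}".format(scaled_lr(16 * 1024 * 0.6)))  # prints 1.49e-05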