diff --git a/tests/cal_lr.py b/tests/cal_lr.py
index 8a932d5e..2a85fa0e 100644
--- a/tests/cal_lr.py
+++ b/tests/cal_lr.py
@@ -22,8 +22,9 @@ BASE_BS = 4_000_000
 def calculate_lr(
     model_name_or_path: str,
     dataset: str,
-    cutoff_len: int,
-    batch_size: int
+    cutoff_len: int,  # i.e. maximum input length during training
+    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
+    is_mistral: bool  # Mistral models use a smaller learning rate
 ):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
         stage="sft",
@@ -54,6 +55,7 @@ def calculate_lr(
     valid_ratio = valid_tokens / total_tokens
     batch_valid_len = batch_max_len * valid_ratio
     lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
+    lr = lr / 6.0 if is_mistral else lr
     print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
        lr, valid_ratio * 100, batch_valid_len
     ))
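
For reference, below is a minimal standalone sketch of the square-root scaling rule this change touches. It assumes `BASE_LR = 3e-4` (only `BASE_BS` is visible in the hunk context) and uses hypothetical example inputs; it is an illustration of the formula, not the script itself.

```python
import math

# Reference constants; BASE_LR is an assumption, BASE_BS matches the hunk context.
BASE_LR = 3e-4
BASE_BS = 4_000_000  # reference number of valid tokens per optimization step


def scaled_lr(batch_valid_len: float, is_mistral: bool = False) -> float:
    """Square-root scaling: lr grows with the effective (valid-token) batch size."""
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
    # Mistral models use a smaller learning rate, hence the extra division.
    return lr / 6.0 if is_mistral else lr


if __name__ == "__main__":
    # Hypothetical case: batch_size=32, cutoff_len=1024, valid_ratio=0.8
    # -> 32 * 1024 * 0.8 valid tokens per optimization step.
    batch_valid_len = 32 * 1024 * 0.8
    print(f"baseline: {scaled_lr(batch_valid_len):.2e}")
    print(f"mistral:  {scaled_lr(batch_valid_len, is_mistral=True):.2e}")
```

The added diff line applies the same Mistral adjustment after the existing formula, leaving every other model's learning rate unchanged.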