Update cal_lr.py
commit 5619e76dc5
parent fcb2daf7f3
@@ -15,8 +15,8 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
 
 
-BASE_LR = 3e-4
-BASE_BS = 4_000_000
+BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models
+BASE_BS = 4_000_000 # from llama paper
 
 
 def calculate_lr(
@@ -54,7 +54,7 @@ def calculate_lr(
     batch_max_len = cutoff_len * batch_size # max tokens in a batch
     valid_ratio = valid_tokens / total_tokens
     batch_valid_len = batch_max_len * valid_ratio
-    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
+    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral else lr
     print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
         lr, valid_ratio * 100, batch_valid_len
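Context for the comments this commit adds: the script scales the LLaMA pre-training defaults (lr 3e-4 at a batch of 4M tokens) by the square root of the effective batch size measured in valid (non-padding) tokens. Below is a minimal standalone sketch of that arithmetic; the sample inputs (cutoff_len=1024, batch_size=512, valid_ratio=0.60) are illustrative assumptions, not values from the repository.

import math

BASE_LR = 3e-4      # LLaMA pre-training lr (1.5e-4 for 30B-70B models)
BASE_BS = 4_000_000 # tokens per batch, from the llama paper

def scaled_lr(cutoff_len, batch_size, valid_ratio, is_mistral=False):
    # Mirrors the arithmetic in calculate_lr: lr grows with the square
    # root of the number of valid tokens per batch.
    batch_max_len = cutoff_len * batch_size        # max tokens in a batch
    batch_valid_len = batch_max_len * valid_ratio  # non-padding tokens only
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    return lr / 6.0 if is_mistral else lr

# Illustrative call with assumed values:
# 3e-4 * sqrt(1024 * 512 * 0.60 / 4e6) is roughly 8.4e-05
print("{:.2e}".format(scaled_lr(1024, 512, 0.60)))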