Update cal_lr.py
This commit is contained in:
parent
42c8fc4fb9
commit
fcb2daf7f3
|
@ -22,8 +22,9 @@ BASE_BS = 4_000_000
|
||||||
def calculate_lr(
|
def calculate_lr(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
dataset: str,
|
dataset: str,
|
||||||
cutoff_len: int,
|
cutoff_len: int, # i.e. maximum input length during training
|
||||||
batch_size: int
|
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||||
|
is_mistral: bool # mistral model uses a smaller learning rate
|
||||||
):
|
):
|
||||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict(
|
||||||
stage="sft",
|
stage="sft",
|
||||||
|
@ -54,6 +55,7 @@ def calculate_lr(
|
||||||
valid_ratio = valid_tokens / total_tokens
|
valid_ratio = valid_tokens / total_tokens
|
||||||
batch_valid_len = batch_max_len * valid_ratio
|
batch_valid_len = batch_max_len * valid_ratio
|
||||||
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
|
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
|
||||||
|
lr = lr / 6.0 if is_mistral else lr
|
||||||
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
print("Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
|
||||||
lr, valid_ratio * 100, batch_valid_len
|
lr, valid_ratio * 100, batch_valid_len
|
||||||
))
|
))
|
||||||
|
|
Loading…
Reference in New Issue