update scripts

hiyouga 2024-07-03 20:07:44 +08:00
parent 8845e94f91
commit 1e0c860c8c
2 changed files with 4 additions and 2 deletions


@@ -44,6 +44,7 @@ def calculate_lr(
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False,  # Mistral models use a smaller learning rate
+ packing: bool = False,
):
r"""
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
@@ -57,6 +58,7 @@ def calculate_lr(
dataset_dir=dataset_dir,
template=template,
cutoff_len=cutoff_len,
+ packing=packing,
output_dir="dummy_dir",
overwrite_cache=True,
)
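
The new packing option is simply threaded through to the dataset arguments. A minimal sketch of calling the updated function with packing enabled (the import path and the elided model/dataset arguments are assumptions; only template, cutoff_len, and packing appear in this diff):

from cal_lr import calculate_lr  # hypothetical import path for this script

calculate_lr(
    # ... plus the model/dataset arguments the script already requires ...
    template="default",
    cutoff_len=1024,
    packing=True,  # new: pack several samples into each cutoff_len window
)
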
@@ -69,7 +71,7 @@ def calculate_lr(
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
else:
- raise NotImplementedError
+ raise NotImplementedError("Unsupported stage: {}.".format(stage))
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
valid_tokens, total_tokens = 0, 0
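
For context, the two counters initialized above are typically filled by walking the dataloader and checking labels against IGNORE_INDEX; a hedged sketch of such a loop (assumed, not part of this diff):

import torch

for batch in dataloader:
    # Labels equal to IGNORE_INDEX mark prompt or padding positions
    # that contribute no training loss.
    valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
    total_tokens += batch["labels"].numel()

With packing=True, samples are concatenated up to cutoff_len, so the padding share shrinks and the valid-token ratio rises, which is presumably why the flag now has to reach the dataset arguments for the token-based learning-rate estimate to stay accurate.
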


@@ -98,7 +98,7 @@ def cal_ppl(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
- raise NotImplementedError
+ raise NotImplementedError("Unsupported stage: {}.".format(stage))
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
criterion = torch.nn.CrossEntropyLoss(reduction="none")
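
With reduction="none", the criterion keeps one loss value per token position, which is what makes per-sample perplexities possible. A hedged sketch of how such token losses are usually reduced to perplexities (the helper, the shapes, and the IGNORE_INDEX value are assumptions, not code from this diff):

import torch

IGNORE_INDEX = -100  # assumed sentinel for masked labels, as used by the collator above

def sample_perplexities(logits, labels, criterion):
    # Hypothetical helper: logits (batch, seq, vocab), labels (batch, seq).
    shift_logits = logits[:, :-1, :]  # position t predicts token t+1
    shift_labels = labels[:, 1:]
    # CrossEntropyLoss expects (batch, vocab, seq); with reduction="none"
    # it returns one loss per position, and zero at ignored positions.
    token_loss = criterion(shift_logits.transpose(1, 2), shift_labels)
    mask = (shift_labels != IGNORE_INDEX).float()
    seq_loss = (token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return torch.exp(seq_loss)  # perplexity of each sample in the batch
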