diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py
index 7bf8839d..dd864162 100644
--- a/scripts/cal_lr.py
+++ b/scripts/cal_lr.py
@@ -4,6 +4,7 @@
 # Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
 
 import math
+from typing import Literal
 
 import fire
 import torch
@@ -24,7 +25,7 @@ BASE_BS = 4_000_000  # from llama paper
 def calculate_lr(
     model_name_or_path: str,
     batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
-    stage: str = "sft",
+    stage: Literal["pt", "sft"] = "sft",
     dataset: str = "alpaca_en",
     dataset_dir: str = "data",
     template: str = "default",
diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py
index 06c2a43b..2e74c70a 100644
--- a/scripts/cal_ppl.py
+++ b/scripts/cal_ppl.py
@@ -3,7 +3,8 @@
 # Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
 
 import json
-from typing import Dict
+from dataclasses import dataclass
+from typing import Any, Dict, Literal, Sequence
 
 import fire
 import torch
@@ -17,11 +18,37 @@ from llmtuner.hparams import get_train_args
 from llmtuner.model import load_model, load_tokenizer
 
 
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    train_on_prompt: bool = False
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        For perplexity calculation, only the chosen response of each pairwise example is
+        kept, so the collator returns n examples for n input pairs.
+        """
+        chosen_features = []
+        for feature in features:
+            prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
+            input_ids = feature["prompt_ids"] + feature["chosen_ids"]
+            attention_mask = [1] * (prompt_len + answer_len)
+            labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
+            chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
+
+        return super().__call__(chosen_features)
+
+
 def cal_ppl(
     model_name_or_path: str,
     save_name: str,
     batch_size: int = 4,
-    stage: str = "sft",
+    stage: Literal["pt", "sft", "rm"] = "sft",
     dataset: str = "alpaca_en",
     dataset_dir: str = "data",
     template: str = "default",
@@ -49,6 +76,10 @@ def cal_ppl(
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    elif stage == "rm":
+        data_collator = PairwiseDataCollatorWithPadding(
+            tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
+        )
     else:
         raise NotImplementedError
 
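
A minimal standalone sketch of what the new "rm" branch does with a batch, assuming the PairwiseDataCollatorWithPadding class from the patch is in scope; the gpt2 tokenizer, the IGNORE_INDEX value, and the toy feature dicts below are illustrative assumptions, not part of the patch (the real feature dicts come from llmtuner's get_dataset()):

# Sketch only: tokenizer choice and toy token ids are assumptions for illustration.
from transformers import AutoTokenizer

IGNORE_INDEX = -100  # mirrors the label-masking constant imported by the script

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

collator = PairwiseDataCollatorWithPadding(
    tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=False
)

# Each pairwise record carries prompt/chosen/rejected token ids; only the chosen
# response is re-packed into a standard seq2seq feature for the perplexity pass.
features = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7, 8], "chosen_ids": [9, 10, 11], "rejected_ids": [12, 13]},
]

batch = collator(features)
# With train_on_prompt=False, prompt positions in batch["labels"] are IGNORE_INDEX,
# so only the chosen response tokens contribute to the loss and hence to the ppl.

Because the collator simply converts each pairwise record into an ordinary (input_ids, attention_mask, labels) feature before deferring to DataCollatorForSeq2Seq, the padding, DataLoader, and loss computation downstream stay unchanged.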