update scripts
This commit is contained in:
parent
25aeaae51b
commit
c1a53a0deb
|
@ -4,6 +4,7 @@
|
|||
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import fire
|
||||
import torch
|
||||
|
@ -24,7 +25,7 @@ BASE_BS = 4_000_000 # from llama paper
|
|||
def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
stage: str = "sft",
|
||||
stage: Literal["pt", "sft"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
|
||||
|
||||
import json
|
||||
from typing import Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Literal, Sequence
|
||||
|
||||
import fire
|
||||
import torch
|
||||
|
@ -17,11 +18,37 @@ from llmtuner.hparams import get_train_args
|
|||
from llmtuner.model import load_model, load_tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
train_on_prompt: bool = False
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
chosen_features = []
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
|
||||
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
|
||||
attention_mask = [1] * (prompt_len + answer_len)
|
||||
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
|
||||
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
|
||||
|
||||
return super().__call__(chosen_features)
|
||||
|
||||
|
||||
def cal_ppl(
|
||||
model_name_or_path: str,
|
||||
save_name: str,
|
||||
batch_size: int = 4,
|
||||
stage: str = "sft",
|
||||
stage: Literal["pt", "sft", "rm"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
|
@ -49,6 +76,10 @@ def cal_ppl(
|
|||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
elif stage == "rm":
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
Loading…
Reference in New Issue