update scripts

This commit is contained in:
hiyouga 2024-05-04 23:05:17 +08:00
parent 25aeaae51b
commit c1a53a0deb
2 changed files with 35 additions and 3 deletions

View File

@ -4,6 +4,7 @@
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
from typing import Literal
import fire
import torch
@ -24,7 +25,7 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
stage: str = "sft",
stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",

View File

@ -3,7 +3,8 @@
# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
import json
from typing import Dict
from dataclasses import dataclass
from typing import Any, Dict, Literal, Sequence
import fire
import torch
@ -17,11 +18,37 @@ from llmtuner.hparams import get_train_args
from llmtuner.model import load_model, load_tokenizer
@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
r"""
Data collator for pairwise data.
"""
train_on_prompt: bool = False
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
r"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
chosen_features = []
for feature in features:
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
input_ids = feature["prompt_ids"] + feature["chosen_ids"]
attention_mask = [1] * (prompt_len + answer_len)
labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
return super().__call__(chosen_features)
def cal_ppl(
model_name_or_path: str,
save_name: str,
batch_size: int = 4,
stage: str = "sft",
stage: Literal["pt", "sft", "rm"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",
@ -49,6 +76,10 @@ def cal_ppl(
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
elif stage == "rm":
data_collator = PairwiseDataCollatorWithPadding(
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
)
else:
raise NotImplementedError