update scripts
parent 25aeaae51b, commit c1a53a0deb
@@ -4,6 +4,7 @@
 # Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
 
 import math
+from typing import Literal
 
 import fire
 import torch
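
The new Literal import backs the narrowed stage annotation in the next hunk. Such annotations document the accepted values for readers and type checkers, but they are not enforced at runtime (and python-fire, to my knowledge, does not validate against them), so strict validation would need a guard along these lines; this sketch is mine, not part of the commit:

    def check_stage(stage: str) -> None:
        # Literal annotations are not enforced at runtime; validate explicitly.
        if stage not in ("pt", "sft"):
            raise ValueError(f"unsupported stage: {stage}")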
@@ -24,7 +25,7 @@ BASE_BS = 4_000_000  # from llama paper
 def calculate_lr(
     model_name_or_path: str,
     batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
-    stage: str = "sft",
+    stage: Literal["pt", "sft"] = "sft",
     dataset: str = "alpaca_en",
     dataset_dir: str = "data",
     template: str = "default",
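
The body of calculate_lr lies outside this diff, but the BASE_BS context line ("from llama paper") together with the math import suggests the usual square-root batch-size scaling rule. A minimal sketch of that rule, assuming a BASE_LR constant and a helper name that do not appear in the shown hunks:

    import math

    BASE_LR = 3.0e-4     # assumed base learning rate; not shown in this diff
    BASE_BS = 4_000_000  # base batch size in tokens, per the diff context

    def sqrt_scaled_lr(tokens_per_step: int) -> float:
        # Scale the base LR by the square root of the batch-size ratio,
        # e.g. a 1M-token batch gives 3e-4 * sqrt(1/4) = 1.5e-4.
        return BASE_LR * math.sqrt(tokens_per_step / BASE_BS)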
@@ -3,7 +3,8 @@
 # Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
 
 import json
-from typing import Dict
+from dataclasses import dataclass
+from typing import Any, Dict, Literal, Sequence
 
 import fire
 import torch
@@ -17,11 +18,37 @@ from llmtuner.hparams import get_train_args
 from llmtuner.model import load_model, load_tokenizer
 
 
+@dataclass
+class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    train_on_prompt: bool = False
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        Only the chosen responses are kept, so the perplexity is computed over
+        the prompt + chosen sequences and the rejected responses are ignored.
+        """
+        chosen_features = []
+        for feature in features:
+            prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
+            input_ids = feature["prompt_ids"] + feature["chosen_ids"]
+            attention_mask = [1] * (prompt_len + answer_len)
+            labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
+            chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})
+
+        return super().__call__(chosen_features)
+
+
 def cal_ppl(
     model_name_or_path: str,
     save_name: str,
     batch_size: int = 4,
-    stage: str = "sft",
+    stage: Literal["pt", "sft", "rm"] = "sft",
     dataset: str = "alpaca_en",
     dataset_dir: str = "data",
     template: str = "default",
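
As a quick illustration (not part of the commit), the new collator can be exercised with toy token ids. With the class above in scope, everything below except IGNORE_INDEX (which llmtuner defines as -100) is made up for the example:

    from transformers import AutoTokenizer

    IGNORE_INDEX = -100  # llmtuner's label-padding constant

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works here
    tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default

    collator = PairwiseDataCollatorWithPadding(
        tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=False
    )
    batch = collator([
        {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
        {"prompt_ids": [7], "chosen_ids": [8, 9, 10], "rejected_ids": [11]},
    ])
    # batch["input_ids"] has shape (2, 5): each row is prompt + chosen, padded
    # to the longest sequence; prompt positions in labels are IGNORE_INDEX.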
@@ -49,6 +76,10 @@ def cal_ppl(
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    elif stage == "rm":
+        data_collator = PairwiseDataCollatorWithPadding(
+            tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
+        )
     else:
         raise NotImplementedError
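
With the rm branch in place, perplexity can presumably be computed over the chosen responses of a pairwise dataset, following the usage line at the top of the script (the train_on_prompt value referenced above appears to come from a cal_ppl parameter outside the shown hunks):

    python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json --stage rm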