Merge pull request #4691 from codemayq/feature-suppot-eval-dataset
add eval dataset support
This commit is contained in:
commit
15b399a82f
|
@ -11,8 +11,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
||||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||||
"subset": "the name of the subset. (optional, default: None)",
|
"subset": "the name of the subset. (optional, default: None)",
|
||||||
|
"split": "the name of dataset split to be used. (optional, default: train)",
|
||||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
"num_samples": "the number of samples in the dataset to be used. (optional, default: None)",
|
||||||
"columns (optional)": {
|
"columns (optional)": {
|
||||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
"subset": "数据集子集的名称(可选,默认:None)",
|
"subset": "数据集子集的名称(可选,默认:None)",
|
||||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
||||||
|
"split": "数据集中的要使用的训练测试集切分(可选,默认:train)",
|
||||||
"columns(可选)": {
|
"columns(可选)": {
|
||||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||||
"query": "数据集代表请求的表头名称(默认:input)",
|
"query": "数据集代表请求的表头名称(默认:input)",
|
||||||
|
|
|
@ -172,9 +172,19 @@
|
||||||
"deepctrl": {
|
"deepctrl": {
|
||||||
"ms_hub_url": "deepctrl/deepctrl-sft-data"
|
"ms_hub_url": "deepctrl/deepctrl-sft-data"
|
||||||
},
|
},
|
||||||
"adgen": {
|
"adgen_train": {
|
||||||
"hf_hub_url": "HasturOfficial/adgen",
|
"hf_hub_url": "HasturOfficial/adgen",
|
||||||
"ms_hub_url": "AI-ModelScope/adgen",
|
"ms_hub_url": "AI-ModelScope/adgen",
|
||||||
|
"split": "train",
|
||||||
|
"columns": {
|
||||||
|
"prompt": "content",
|
||||||
|
"response": "summary"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"adgen_val": {
|
||||||
|
"hf_hub_url": "HasturOfficial/adgen",
|
||||||
|
"ms_hub_url": "AI-ModelScope/adgen",
|
||||||
|
"split": "validation",
|
||||||
"columns": {
|
"columns": {
|
||||||
"prompt": "content",
|
"prompt": "content",
|
||||||
"response": "summary"
|
"response": "summary"
|
||||||
|
|
|
@ -65,7 +65,7 @@ def calculate_lr(
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
elif stage == "sft":
|
elif stage == "sft":
|
||||||
|
@ -73,7 +73,7 @@ def calculate_lr(
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(dataset_module["eval_dataset"], batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
valid_tokens, total_tokens = 0, 0
|
valid_tokens, total_tokens = 0, 0
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader):
|
||||||
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
|
||||||
|
|
|
@ -87,7 +87,7 @@ def cal_ppl(
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
@ -100,7 +100,7 @@ def cal_ppl(
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||||
|
|
||||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
dataloader = DataLoader(dataset_module["eval_dataset"], batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
total_ppl = 0
|
total_ppl = 0
|
||||||
perplexities = []
|
perplexities = []
|
||||||
|
|
|
@ -47,10 +47,10 @@ def length_cdf(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
total_num = len(trainset)
|
total_num = len(dataset_module["eval_dataset"])
|
||||||
length_dict = defaultdict(int)
|
length_dict = defaultdict(int)
|
||||||
for sample in tqdm(trainset["input_ids"]):
|
for sample in tqdm(dataset_module["eval_dataset"]["input_ids"]):
|
||||||
length_dict[len(sample) // interval * interval] += 1
|
length_dict[len(sample) // interval * interval] += 1
|
||||||
|
|
||||||
length_tuples = list(length_dict.items())
|
length_tuples = list(length_dict.items())
|
||||||
|
|
|
@ -13,16 +13,15 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
||||||
|
|
||||||
from datasets import concatenate_datasets, interleave_datasets
|
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import Seq2SeqTrainingArguments
|
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
@ -42,24 +41,29 @@ class Role(str, Enum):
|
||||||
OBSERVATION = "observation"
|
OBSERVATION = "observation"
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetModule(TypedDict):
|
||||||
|
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||||
|
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||||
|
|
||||||
|
|
||||||
def merge_dataset(
|
def merge_dataset(
|
||||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||||
data_args: "DataArguments",
|
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
if len(all_datasets) == 1:
|
if len(all_datasets) == 1:
|
||||||
return all_datasets[0]
|
return all_datasets[0]
|
||||||
elif data_args.mix_strategy == "concat":
|
elif data_args.mix_strategy == "concat":
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||||
|
|
||||||
return concatenate_datasets(all_datasets)
|
return concatenate_datasets(all_datasets)
|
||||||
elif data_args.mix_strategy.startswith("interleave"):
|
elif data_args.mix_strategy.startswith("interleave"):
|
||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
|
|
||||||
return interleave_datasets(
|
return interleave_datasets(
|
||||||
datasets=all_datasets,
|
datasets=all_datasets,
|
||||||
probabilities=data_args.interleave_probs,
|
probabilities=data_args.interleave_probs,
|
||||||
seed=training_args.seed,
|
seed=seed,
|
||||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -67,22 +71,17 @@ def merge_dataset(
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(
|
def split_dataset(
|
||||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
||||||
) -> Dict[str, "Dataset"]:
|
) -> "DatasetDict":
|
||||||
if training_args.do_train:
|
r"""
|
||||||
if data_args.val_size > 1e-6: # Split the dataset
|
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
|
||||||
if data_args.streaming:
|
"""
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
if data_args.streaming:
|
||||||
val_set = dataset.take(int(data_args.val_size))
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
||||||
train_set = dataset.skip(int(data_args.val_size))
|
val_set = dataset.take(int(data_args.val_size))
|
||||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
train_set = dataset.skip(int(data_args.val_size))
|
||||||
else:
|
return DatasetDict({"train": train_set, "validation": val_set})
|
||||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
else:
|
||||||
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
|
||||||
else:
|
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
|
||||||
if data_args.streaming:
|
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
|
||||||
return {"train_dataset": dataset}
|
|
||||||
else: # do_eval or do_predict
|
|
||||||
return {"eval_dataset": dataset}
|
|
||||||
|
|
|
@ -12,19 +12,19 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import inspect
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset, load_from_disk
|
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.constants import FILEEXT2TYPE
|
from ..extras.constants import FILEEXT2TYPE
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.misc import has_tokenized_data
|
from ..extras.misc import has_tokenized_data
|
||||||
from .aligner import align_dataset
|
from .aligner import align_dataset
|
||||||
from .data_utils import merge_dataset
|
from .data_utils import merge_dataset, split_dataset
|
||||||
from .parser import get_dataset_list
|
from .parser import get_dataset_list
|
||||||
from .preprocess import get_preprocess_and_print_func
|
from .preprocess import get_preprocess_and_print_func
|
||||||
from .template import get_template_and_fix_tokenizer
|
from .template import get_template_and_fix_tokenizer
|
||||||
|
@ -35,13 +35,15 @@ if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||||
|
|
||||||
from ..hparams import DataArguments, ModelArguments
|
from ..hparams import DataArguments, ModelArguments
|
||||||
|
from .data_utils import DatasetModule
|
||||||
from .parser import DatasetAttr
|
from .parser import DatasetAttr
|
||||||
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_single_dataset(
|
def _load_single_dataset(
|
||||||
dataset_attr: "DatasetAttr",
|
dataset_attr: "DatasetAttr",
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
|
@ -81,41 +83,34 @@ def load_single_dataset(
|
||||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||||
|
|
||||||
if dataset_attr.load_from == "ms_hub":
|
if dataset_attr.load_from == "ms_hub":
|
||||||
try:
|
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||||
from modelscope import MsDataset
|
from modelscope import MsDataset
|
||||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||||
|
|
||||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||||
dataset = MsDataset.load(
|
dataset = MsDataset.load(
|
||||||
dataset_name=data_path,
|
dataset_name=data_path,
|
||||||
subset_name=data_name,
|
subset_name=data_name,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=data_args.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
token=model_args.ms_hub_token,
|
token=model_args.ms_hub_token,
|
||||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
)
|
)
|
||||||
if isinstance(dataset, MsDataset):
|
if isinstance(dataset, MsDataset):
|
||||||
dataset = dataset.to_hf_dataset()
|
dataset = dataset.to_hf_dataset()
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
|
||||||
else:
|
else:
|
||||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
|
||||||
kwargs = {"trust_remote_code": True}
|
|
||||||
else:
|
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
path=data_path,
|
path=data_path,
|
||||||
name=data_name,
|
name=data_name,
|
||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
data_files=data_files,
|
data_files=data_files,
|
||||||
split=data_args.split,
|
split=dataset_attr.split,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||||
**kwargs,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||||
|
@ -140,6 +135,66 @@ def load_single_dataset(
|
||||||
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_merged_dataset(
|
||||||
|
dataset_names: Optional[Sequence[str]],
|
||||||
|
model_args: "ModelArguments",
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
|
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||||
|
if dataset_names is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
datasets = []
|
||||||
|
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
|
||||||
|
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||||
|
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||||
|
|
||||||
|
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||||
|
|
||||||
|
return merge_dataset(datasets, data_args, seed=training_args.seed)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_preprocessed_dataset(
|
||||||
|
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"] = None,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||||
|
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
|
||||||
|
)
|
||||||
|
column_names = list(next(iter(dataset)).keys())
|
||||||
|
kwargs = {}
|
||||||
|
if not data_args.streaming:
|
||||||
|
kwargs = dict(
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||||
|
desc="Running tokenizer on dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||||
|
|
||||||
|
if training_args.should_log:
|
||||||
|
try:
|
||||||
|
print("eval example:" if is_eval else "training example:")
|
||||||
|
print_function(next(iter(dataset)))
|
||||||
|
except StopIteration:
|
||||||
|
if stage == "pt":
|
||||||
|
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
|
@ -147,7 +202,7 @@ def get_dataset(
|
||||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"] = None,
|
processor: Optional["ProcessorMixin"] = None,
|
||||||
) -> Union["Dataset", "IterableDataset"]:
|
) -> "DatasetModule":
|
||||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
||||||
if data_args.train_on_prompt and template.efficient_eos:
|
if data_args.train_on_prompt and template.efficient_eos:
|
||||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||||
|
@ -156,55 +211,66 @@ def get_dataset(
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if has_tokenized_data(data_args.tokenized_path):
|
if has_tokenized_data(data_args.tokenized_path):
|
||||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||||
dataset = load_from_disk(data_args.tokenized_path)
|
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||||
|
|
||||||
|
dataset_module: Dict[str, "Dataset"] = {}
|
||||||
|
if "train" in dataset_dict:
|
||||||
|
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||||
|
if "validation" in dataset_dict:
|
||||||
|
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||||
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
dataset = dataset.to_iterable_dataset()
|
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
|
||||||
return dataset
|
|
||||||
|
return dataset_module
|
||||||
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
||||||
|
|
||||||
|
# Load and preprocess dataset
|
||||||
with training_args.main_process_first(desc="load dataset"):
|
with training_args.main_process_first(desc="load dataset"):
|
||||||
all_datasets = []
|
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
|
||||||
for dataset_attr in get_dataset_list(data_args):
|
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
|
||||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
|
||||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
|
||||||
|
|
||||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
|
||||||
|
|
||||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
|
||||||
|
|
||||||
with training_args.main_process_first(desc="pre-process dataset"):
|
with training_args.main_process_first(desc="pre-process dataset"):
|
||||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
dataset = _get_preprocessed_dataset(
|
||||||
data_args, training_args, stage, template, tokenizer, processor
|
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
|
||||||
|
)
|
||||||
|
eval_dataset = _get_preprocessed_dataset(
|
||||||
|
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
|
||||||
)
|
)
|
||||||
column_names = list(next(iter(dataset)).keys())
|
|
||||||
kwargs = {}
|
|
||||||
if not data_args.streaming:
|
|
||||||
kwargs = dict(
|
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
|
||||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
|
||||||
desc="Running tokenizer on dataset",
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
if data_args.val_size > 1e-6:
|
||||||
|
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
|
||||||
|
else:
|
||||||
|
dataset_dict = {}
|
||||||
|
if dataset is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
|
|
||||||
|
dataset_dict["train"] = dataset
|
||||||
|
|
||||||
|
if eval_dataset is not None:
|
||||||
|
if data_args.streaming:
|
||||||
|
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
|
|
||||||
|
dataset_dict["validation"] = eval_dataset
|
||||||
|
|
||||||
|
dataset_dict = DatasetDict(dataset_dict)
|
||||||
|
|
||||||
if data_args.tokenized_path is not None:
|
if data_args.tokenized_path is not None:
|
||||||
if training_args.should_save:
|
if training_args.should_save:
|
||||||
dataset.save_to_disk(data_args.tokenized_path)
|
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||||
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if training_args.should_log:
|
dataset_module = {}
|
||||||
try:
|
if "train" in dataset_dict:
|
||||||
print_function(next(iter(dataset)))
|
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||||
except StopIteration:
|
if "validation" in dataset_dict:
|
||||||
if stage == "pt":
|
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
|
||||||
|
|
||||||
return dataset
|
return dataset_module
|
||||||
|
|
|
@ -15,16 +15,14 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from ..extras.constants import DATA_CONFIG
|
from ..extras.constants import DATA_CONFIG
|
||||||
from ..extras.misc import use_modelscope
|
from ..extras.misc import use_modelscope
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ..hparams import DataArguments
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DatasetAttr:
|
class DatasetAttr:
|
||||||
r"""
|
r"""
|
||||||
|
@ -38,6 +36,7 @@ class DatasetAttr:
|
||||||
ranking: bool = False
|
ranking: bool = False
|
||||||
# extra configs
|
# extra configs
|
||||||
subset: Optional[str] = None
|
subset: Optional[str] = None
|
||||||
|
split: str = "train"
|
||||||
folder: Optional[str] = None
|
folder: Optional[str] = None
|
||||||
num_samples: Optional[int] = None
|
num_samples: Optional[int] = None
|
||||||
# common columns
|
# common columns
|
||||||
|
@ -71,31 +70,33 @@ class DatasetAttr:
|
||||||
setattr(self, key, obj.get(key, default))
|
setattr(self, key, obj.get(key, default))
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||||
if data_args.dataset is not None:
|
r"""
|
||||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
|
Gets the attributes of the datasets.
|
||||||
else:
|
"""
|
||||||
|
if dataset_names is None:
|
||||||
dataset_names = []
|
dataset_names = []
|
||||||
|
|
||||||
if data_args.dataset_dir == "ONLINE":
|
if dataset_dir == "ONLINE":
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
else:
|
else:
|
||||||
|
if dataset_dir.startswith("REMOTE:"):
|
||||||
|
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||||
|
else:
|
||||||
|
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
with open(config_path, "r") as f:
|
||||||
dataset_info = json.load(f)
|
dataset_info = json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if len(dataset_names) != 0:
|
if len(dataset_names) != 0:
|
||||||
raise ValueError(
|
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
||||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
|
||||||
)
|
|
||||||
dataset_info = None
|
dataset_info = None
|
||||||
|
|
||||||
if data_args.interleave_probs is not None:
|
dataset_list: List["DatasetAttr"] = []
|
||||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
|
||||||
|
|
||||||
dataset_list: List[DatasetAttr] = []
|
|
||||||
for name in dataset_names:
|
for name in dataset_names:
|
||||||
if dataset_info is None:
|
if dataset_info is None: # dataset_dir is ONLINE
|
||||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||||
dataset_list.append(dataset_attr)
|
dataset_list.append(dataset_attr)
|
||||||
|
@ -120,6 +121,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||||
dataset_attr.set_attr("subset", dataset_info[name])
|
dataset_attr.set_attr("subset", dataset_info[name])
|
||||||
|
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
||||||
dataset_attr.set_attr("folder", dataset_info[name])
|
dataset_attr.set_attr("folder", dataset_info[name])
|
||||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .template import Template
|
from .template import Template
|
||||||
|
@ -35,11 +35,11 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
def get_preprocess_and_print_func(
|
def get_preprocess_and_print_func(
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
|
||||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||||
template: "Template",
|
template: "Template",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
|
do_generate: bool = False,
|
||||||
) -> Tuple[Callable, Callable]:
|
) -> Tuple[Callable, Callable]:
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
|
@ -48,7 +48,7 @@ def get_preprocess_and_print_func(
|
||||||
data_args=data_args,
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not do_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
if data_args.neat_packing:
|
if data_args.neat_packing:
|
||||||
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
||||||
|
|
|
@ -31,7 +31,11 @@ class DataArguments:
|
||||||
)
|
)
|
||||||
dataset: Optional[str] = field(
|
dataset: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||||
|
)
|
||||||
|
eval_dataset: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||||
)
|
)
|
||||||
dataset_dir: str = field(
|
dataset_dir: str = field(
|
||||||
default="data",
|
default="data",
|
||||||
|
@ -107,6 +111,31 @@ class DataArguments:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
def split_arg(arg):
|
||||||
|
if isinstance(arg, str):
|
||||||
|
return [item.strip() for item in arg.split(",")]
|
||||||
|
return arg
|
||||||
|
|
||||||
|
self.dataset = split_arg(self.dataset)
|
||||||
|
self.eval_dataset = split_arg(self.eval_dataset)
|
||||||
|
|
||||||
|
if self.dataset is None and self.val_size > 1e-6:
|
||||||
|
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||||
|
|
||||||
|
if self.eval_dataset is not None and self.val_size > 1e-6:
|
||||||
|
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||||
|
|
||||||
|
if self.interleave_probs is not None:
|
||||||
|
if self.mix_strategy == "concat":
|
||||||
|
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
|
||||||
|
|
||||||
|
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
|
||||||
|
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
|
||||||
|
raise ValueError("The length of dataset and interleave probs should be identical.")
|
||||||
|
|
||||||
|
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
|
||||||
|
raise ValueError("The length of eval dataset and interleave probs should be identical.")
|
||||||
|
|
||||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||||
raise ValueError("Streaming mode should have an integer val size.")
|
raise ValueError("Streaming mode should have an integer val size.")
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@ def run_dpo(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
|
||||||
data_collator = PairwiseDataCollatorWithPadding(
|
data_collator = PairwiseDataCollatorWithPadding(
|
||||||
|
@ -71,7 +71,7 @@ def run_dpo(
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**dataset_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
|
@ -41,7 +41,7 @@ def run_kto(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
|
||||||
data_collator = KTODataCollatorWithPadding(
|
data_collator = KTODataCollatorWithPadding(
|
||||||
|
@ -68,7 +68,7 @@ def run_kto(
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**dataset_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
|
@ -43,7 +43,7 @@ def run_ppo(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||||
|
|
||||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||||
|
@ -63,7 +63,7 @@ def run_ppo(
|
||||||
model=model,
|
model=model,
|
||||||
reward_model=reward_model,
|
reward_model=reward_model,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
dataset=dataset,
|
dataset=dataset_module["train_dataset"],
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,7 +42,7 @@ def run_pt(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ def run_pt(
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**dataset_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
|
|
@ -41,7 +41,7 @@ def run_rm(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ def run_rm(
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
compute_metrics=compute_accuracy,
|
compute_metrics=compute_accuracy,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**dataset_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
|
@ -81,7 +81,7 @@ def run_rm(
|
||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
|
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
|
||||||
trainer.log_metrics("predict", predict_results.metrics)
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_metrics("predict", predict_results.metrics)
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_predictions(predict_results)
|
trainer.save_predictions(predict_results)
|
||||||
|
|
|
@ -43,7 +43,7 @@ def run_sft(
|
||||||
):
|
):
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
|
@ -76,7 +76,7 @@ def run_sft(
|
||||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
|
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else compute_accuracy,
|
||||||
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
|
preprocess_logits_for_metrics=None if training_args.predict_with_generate else eval_logit_processor,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**split_dataset(dataset, data_args, training_args),
|
**dataset_module,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
|
@ -105,12 +105,12 @@ def run_sft(
|
||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
|
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||||
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
|
||||||
predict_results.metrics.pop("predict_loss", None)
|
predict_results.metrics.pop("predict_loss", None)
|
||||||
trainer.log_metrics("predict", predict_results.metrics)
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_metrics("predict", predict_results.metrics)
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
trainer.save_predictions(dataset, predict_results)
|
trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
|
||||||
|
|
||||||
# Create model card
|
# Create model card
|
||||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||||
|
|
|
@ -47,7 +47,7 @@ def test_supervised(num_samples: int):
|
||||||
model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS)
|
model_args, data_args, training_args, _, _ = get_train_args(TRAIN_ARGS)
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
|
|
||||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||||
|
|
||||||
|
@ -63,5 +63,5 @@ def test_supervised(num_samples: int):
|
||||||
{"role": "assistant", "content": original_data[index]["output"]},
|
{"role": "assistant", "content": original_data[index]["output"]},
|
||||||
]
|
]
|
||||||
templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
|
templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
decoded_result = tokenizer.decode(tokenized_data["input_ids"][index])
|
decoded_result = tokenizer.decode(dataset_module["train_dataset"]["input_ids"][index])
|
||||||
assert templated_result == decoded_result
|
assert templated_result == decoded_result
|
||||||
|
|
Loading…
Reference in New Issue