Update loader.py

This commit is contained in:
hoshi-hiyouga 2024-04-26 03:33:07 +08:00 committed by GitHub
parent e16f128dc3
commit 3408af236f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 4 deletions

View File

@ -16,12 +16,13 @@ from .utils import checksum, merge_dataset
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from transformers import AutoProcessor, Seq2SeqTrainingArguments
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from ..hparams import DataArguments, ModelArguments
from .parser import DatasetAttr
logger = get_logger(__name__)
@ -114,12 +115,12 @@ def load_single_dataset(
def get_dataset(
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"],
processor: Optional["AutoProcessor"] = None,
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos:
@ -149,7 +150,7 @@ def get_dataset(
with training_args.main_process_first(desc="pre-process dataset"):
preprocess_func, print_function = get_preprocess_and_print_func(
tokenizer, template, data_args, training_args, stage, processor
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
kwargs = {}