Update preprocess.py

This commit is contained in:
hoshi-hiyouga 2024-07-15 00:55:36 +08:00 committed by GitHub
parent 84e4047f8a
commit df52fb05b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments from transformers import PreTrainedTokenizer, ProcessorMixin
from ..hparams import DataArguments from ..hparams import DataArguments
from .template import Template from .template import Template
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
def get_preprocess_and_print_func( def get_preprocess_and_print_func(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template", template: "Template",
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
do_generate: bool = False,
) -> Tuple[Callable, Callable]: ) -> Tuple[Callable, Callable]:
if stage == "pt": if stage == "pt":
preprocess_func = partial( preprocess_func = partial(
@@ -48,7 +48,7 @@ def get_preprocess_and_print_func(
data_args=data_args, data_args=data_args,
) )
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
elif stage == "sft" and not training_args.predict_with_generate: elif stage == "sft" and not do_generate:
if data_args.packing: if data_args.packing:
if data_args.neat_packing: if data_args.neat_packing:
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence