Update preprocess.py

This commit is contained in:
hoshi-hiyouga 2024-07-15 00:55:36 +08:00 committed by GitHub
parent 84e4047f8a
commit df52fb05b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 3 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING:
-from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
+from transformers import PreTrainedTokenizer, ProcessorMixin
from ..hparams import DataArguments
from .template import Template
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
def get_preprocess_and_print_func(
data_args: "DataArguments",
-training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
template: "Template",
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
+do_generate: bool = False,
) -> Tuple[Callable, Callable]:
if stage == "pt":
preprocess_func = partial(
@@ -48,7 +48,7 @@ def get_preprocess_and_print_func(
data_args=data_args,
)
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
-elif stage == "sft" and not training_args.predict_with_generate:
+elif stage == "sft" and not do_generate:
if data_args.packing:
if data_args.neat_packing:
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence